This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f30e563a7 [refactor][vectorized] refactor first/last value agg 
functions (#10661)
1f30e563a7 is described below

commit 1f30e563a708cc5c1800f8d469a43f2f011592f2
Author: zhangstar333 <[email protected]>
AuthorDate: Sat Jul 30 18:38:56 2022 +0800

    [refactor][vectorized] refactor first/last value agg functions (#10661)
    
    * refactor first and last
    [refactor][vectorized] refactor first/last value agg functions
    
    * add some change
    
    * remove the always-nullable handling from first/last
    
    * remove the always-nullable variants and update their registration
    
    * refactor Value: remove the bool null flag
    
    * refactor window first/last to store a column pointer and row position
---
 be/src/vec/CMakeLists.txt                          |   1 -
 .../aggregate_function_reader.cpp                  |  22 +-
 .../aggregate_function_reader.h                    |   2 +-
 .../aggregate_function_reader_first_last.h         | 287 +++++++++++++++++++
 .../aggregate_function_simple_factory.cpp          |   5 +-
 .../aggregate_function_window.cpp                  |  89 ++++--
 .../aggregate_function_window.h                    | 312 ++++-----------------
 be/src/vec/exec/join/vhash_join_node.cpp           |   8 -
 be/src/vec/utils/template_helpers.hpp              |  22 +-
 9 files changed, 434 insertions(+), 314 deletions(-)

diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index 6be1a9c797..4e30835e4b 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -41,7 +41,6 @@ set(VEC_FILES
   aggregate_functions/aggregate_function_group_concat.cpp
   aggregate_functions/aggregate_function_percentile_approx.cpp
   aggregate_functions/aggregate_function_simple_factory.cpp
-  aggregate_functions/aggregate_function_java_udaf.h
   aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
   columns/column.cpp
   columns/column_array.cpp
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
index 3e27f30b85..8a3bea08bd 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
@@ -44,23 +44,19 @@ void 
register_aggregate_function_replace_reader_load(AggregateFunctionSimpleFact
         factory.register_function(name + suffix, creator, nullable);
     };
 
-    register_function("replace", AGG_READER_SUFFIX, 
create_aggregate_function_first<false, true>,
-                      false);
-    register_function("replace", AGG_READER_SUFFIX, 
create_aggregate_function_first<true, true>,
-                      true);
-    register_function("replace", AGG_LOAD_SUFFIX, 
create_aggregate_function_last<false, false>,
-                      false);
-    register_function("replace", AGG_LOAD_SUFFIX, 
create_aggregate_function_last<true, false>,
-                      true);
+    register_function("replace", AGG_READER_SUFFIX, 
create_aggregate_function_first<true>, false);
+    register_function("replace", AGG_READER_SUFFIX, 
create_aggregate_function_first<true>, true);
+    register_function("replace", AGG_LOAD_SUFFIX, 
create_aggregate_function_last<false>, false);
+    register_function("replace", AGG_LOAD_SUFFIX, 
create_aggregate_function_last<false>, true);
 
     register_function("replace_if_not_null", AGG_READER_SUFFIX,
-                      create_aggregate_function_first_non_null_value<false, 
true>, false);
+                      create_aggregate_function_first_non_null_value<true>, 
false);
     register_function("replace_if_not_null", AGG_READER_SUFFIX,
-                      create_aggregate_function_first_non_null_value<true, 
true>, true);
+                      create_aggregate_function_first_non_null_value<true>, 
true);
     register_function("replace_if_not_null", AGG_LOAD_SUFFIX,
-                      create_aggregate_function_last_non_null_value<false, 
false>, false);
+                      create_aggregate_function_last_non_null_value<false>, 
false);
     register_function("replace_if_not_null", AGG_LOAD_SUFFIX,
-                      create_aggregate_function_last_non_null_value<true, 
false>, true);
+                      create_aggregate_function_last_non_null_value<false>, 
true);
 }
 
-} // namespace doris::vectorized
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.h 
b/be/src/vec/aggregate_functions/aggregate_function_reader.h
index 86fea6f079..626c06571b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader.h
@@ -20,9 +20,9 @@
 #include "vec/aggregate_functions/aggregate_function_bitmap.h"
 #include "vec/aggregate_functions/aggregate_function_hll_union_agg.h"
 #include "vec/aggregate_functions/aggregate_function_min_max.h"
+#include "vec/aggregate_functions/aggregate_function_reader_first_last.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
 #include "vec/aggregate_functions/aggregate_function_sum.h"
-#include "vec/aggregate_functions/aggregate_function_window.h"
 
 namespace doris::vectorized {
 
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h 
b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
new file mode 100644
index 0000000000..4b7c1e0c98
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
@@ -0,0 +1,318 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "factory_helpers.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+#include "vec/io/io_helper.h"
+#include "vec/utils/template_helpers.hpp"
+
+namespace doris::vectorized {
+
+// Non-owning reference to a single row of a source column (column pointer + row offset).
+// Nothing is copied, so the referenced column must stay alive while this value is read.
+// A null `_ptr` means "no row referenced"; with `arg_is_nullable` the column is a
+// ColumnNullable and the referenced row itself may be NULL.
+template <typename ColVecType, bool arg_is_nullable>
+struct Value {
+public:
+    // True when no row is referenced, or when the referenced (nullable) row is NULL.
+    bool is_null() const {
+        if (_ptr == nullptr) {
+            return true;
+        }
+        if constexpr (arg_is_nullable) {
+            return assert_cast<const ColumnNullable*>(_ptr)->is_null_at(_offset);
+        }
+        return false;
+    }
+
+    // Raw bytes of the referenced row, read from the nested column when nullable.
+    // Precondition: a row is referenced (`_ptr != nullptr`).
+    StringRef get_value() const {
+        if constexpr (arg_is_nullable) {
+            auto* col = assert_cast<const ColumnNullable*>(_ptr);
+            return assert_cast<const ColVecType&>(col->get_nested_column()).get_data_at(_offset);
+        } else {
+            return assert_cast<const ColVecType*>(_ptr)->get_data_at(_offset);
+        }
+    }
+
+    void set_value(const IColumn* column, size_t row) {
+        _ptr = column;
+        _offset = row;
+    }
+
+    void reset() {
+        _ptr = nullptr;
+        _offset = 0;
+    }
+
+protected:
+    const IColumn* _ptr = nullptr;
+    size_t _offset = 0;
+};
+
+// Variant of Value that copies the row's bytes into an owned std::string, so the stored value
+// does not depend on the source column staying alive. Here `_ptr` is only a "has value" marker
+// and is never dereferenced; a NULL input row is stored as "no value".
+template <typename ColVecType, bool arg_is_nullable>
+struct CopiedValue : public Value<ColVecType, arg_is_nullable> {
+public:
+    StringRef get_value() const { return _copied_value; }
+
+    bool is_null() const { return this->_ptr == nullptr; }
+
+    void set_value(const IColumn* column, size_t row) {
+        // `_ptr` only marks that a value is present (is_null() compares it with nullptr) and is
+        // never dereferenced, so store the real, non-null column pointer rather than a fabricated
+        // address. A NULL input row resets it back to nullptr below.
+        this->_ptr = column;
+        if constexpr (arg_is_nullable) {
+            auto* col = assert_cast<const ColumnNullable*>(column);
+            if (col->is_null_at(row)) {
+                this->reset();
+                return;
+            } else {
+                _copied_value = assert_cast<const ColVecType&>(col->get_nested_column())
+                                        .get_data_at(row)
+                                        .to_string();
+            }
+        } else {
+            _copied_value = assert_cast<const ColVecType*>(column)->get_data_at(row).to_string();
+        }
+    }
+
+private:
+    std::string _copied_value;
+};
+
+// Per-group state for the first/last reader functions. `is_copy` selects CopiedValue (owns the
+// bytes) instead of Value (references the source column).
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable, bool is_copy>
+struct ReaderFirstAndLastData {
+public:
+    using StoreType = std::conditional_t<is_copy, CopiedValue<ColVecType, arg_is_nullable>,
+                                         Value<ColVecType, arg_is_nullable>>;
+    static constexpr bool nullable = arg_is_nullable;
+
+    void reset() {
+        _data_value.reset();
+        _has_value = false;
+    }
+
+    // Appends the stored value to `to`. When nothing (or a NULL) was stored, a nullable result
+    // gets NULL and a non-nullable result gets the column default, rather than dereferencing
+    // an unset pointer or reusing a stale copied value.
+    void insert_result_into(IColumn& to) const {
+        if constexpr (result_is_nullable) {
+            auto& col = assert_cast<ColumnNullable&>(to);
+            if (_data_value.is_null()) { // no value stored, or the stored row is NULL
+                col.insert_default();
+            } else {
+                // get_value() never sees a NULL row here: is_null() was checked above
+                const StringRef& value = _data_value.get_value();
+                col.get_null_map_data().push_back(0);
+                assert_cast<ColVecType&>(col.get_nested_column())
+                        .insert_data(value.data, value.size);
+            }
+        } else {
+            auto& col = assert_cast<ColVecType&>(to);
+            if (_data_value.is_null()) {
+                col.insert_default();
+            } else {
+                const StringRef& value = _data_value.get_value();
+                col.insert_data(value.data, value.size);
+            }
+        }
+    }
+
+    // Stores row `pos` of columns[0] without checking it for NULL; callers that must skip
+    // NULL rows (the *_non_null variants) check before calling.
+    void set_value(const IColumn** columns, size_t pos) {
+        _data_value.set_value(columns[0], pos);
+        _has_value = true;
+    }
+
+    bool has_set_value() const { return _has_value; }
+
+protected:
+    StoreType _data_value;
+    bool _has_value = false;
+};
+
+// Keeps the first row added to the group.
+template <typename Data>
+struct ReaderFunctionFirstData : Data {
+    void add(int64_t row, const IColumn** columns) {
+        if (this->has_set_value()) {
+            return;
+        }
+        this->set_value(columns, row);
+    }
+    static const char* name() { return "first_value"; }
+};
+
+// Keeps the first non-NULL row added to the group.
+template <typename Data>
+struct ReaderFunctionFirstNonNullData : Data {
+    void add(int64_t row, const IColumn** columns) {
+        if (this->has_set_value()) {
+            return;
+        }
+        if constexpr (Data::nullable) {
+            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
+            if (nullable_column->is_null_at(row)) {
+                return;
+            }
+        }
+        this->set_value(columns, row);
+    }
+    static const char* name() { return "first_non_null_value"; }
+};
+
+// Keeps the last row added to the group (each add() overwrites the previous one).
+template <typename Data>
+struct ReaderFunctionLastData : Data {
+    void add(int64_t row, const IColumn** columns) { this->set_value(columns, row); }
+    static const char* name() { return "last_value"; }
+};
+
+// Keeps the last non-NULL row added to the group.
+template <typename Data>
+struct ReaderFunctionLastNonNullData : Data {
+    void add(int64_t row, const IColumn** columns) {
+        if constexpr (Data::nullable) {
+            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
+            if (nullable_column->is_null_at(row)) {
+                return;
+            }
+        }
+        this->set_value(columns, row);
+    }
+
+    static const char* name() { return "last_non_null_value"; }
+};
+
+// Aggregate function wrapper created for the reader/load variants of `replace` and
+// `replace_if_not_null`. It only supports row-by-row add() and insert_result_into();
+// window ranges, merge and (de)serialization abort via LOG(FATAL).
+template <typename Data>
+class ReaderFunctionData final
+        : public IAggregateFunctionDataHelper<Data, ReaderFunctionData<Data>> {
+public:
+    explicit ReaderFunctionData(const DataTypes& argument_types)
+            : IAggregateFunctionDataHelper<Data, ReaderFunctionData<Data>>(argument_types, {}),
+              _argument_type(argument_types[0]) {}
+
+    String get_name() const override { return Data::name(); }
+
+    DataTypePtr get_return_type() const override { return _argument_type; }
+
+    void insert_result_into(ConstAggregateDataPtr place, IColumn& to) const override {
+        this->data(place).insert_result_into(to);
+    }
+
+    void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
+             Arena* arena) const override {
+        this->data(place).add(row_num, columns);
+    }
+
+    void reset(AggregateDataPtr place) const override { this->data(place).reset(); }
+
+    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
+                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
+                                Arena* arena) const override {
+        LOG(FATAL) << "ReaderFunctionData do not support add_range_single_place";
+    }
+    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*) const override {
+        LOG(FATAL) << "ReaderFunctionData do not support merge";
+    }
+    void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const override {
+        LOG(FATAL) << "ReaderFunctionData do not support serialize";
+    }
+    void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) const override {
+        LOG(FATAL) << "ReaderFunctionData do not support deserialize";
+    }
+
+private:
+    DataTypePtr _argument_type;
+};
+
+// Creates AggregateFunctionTemplate<Impl<ReaderFirstAndLastData<...>>> for the column type of
+// argument_types[0] (nullability removed); an unsupported type is reported with LOG(FATAL).
+template <template <typename> class AggregateFunctionTemplate, template <typename> class Impl,
+          bool result_is_nullable, bool arg_is_nullable, bool is_copy = false>
+static IAggregateFunction* create_function_single_value(const String& name,
+                                                        const DataTypes& argument_types,
+                                                        const Array& parameters) {
+    auto type = remove_nullable(argument_types[0]);
+    WhichDataType which(*type);
+
+#define DISPATCH(TYPE, COLUMN_TYPE)                                       \
+    if (which.idx == TypeIndex::TYPE)                                     \
+        return new AggregateFunctionTemplate<Impl<ReaderFirstAndLastData< \
+                COLUMN_TYPE, result_is_nullable, arg_is_nullable, is_copy>>>(argument_types);
+    TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+
+    LOG(FATAL) << "with unknown type, failed in create_aggregate_function_" << name
+               << " and type is: " << argument_types[0]->get_name();
+    return nullptr;
+}
+
+// Defines CREATE_FUNCTION_NAME<is_copy>(name, argument_types, parameters, result_is_nullable).
+// The runtime nullability flags are turned into template arguments with make_bool_variant +
+// std::visit, and then create_function_single_value is called.
+#define CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA)            \
+    template <bool is_copy>                                                                       \
+    AggregateFunctionPtr CREATE_FUNCTION_NAME(const std::string& name,                            \
+                                              const DataTypes& argument_types,                    \
+                                              const Array& parameters, bool result_is_nullable) { \
+        const bool arg_is_nullable = argument_types[0]->is_nullable();                            \
+        AggregateFunctionPtr res = nullptr;                                                       \
+        std::visit(                                                                               \
+                [&](auto result_nullable_tag, auto arg_nullable_tag) {                            \
+                    res = AggregateFunctionPtr(                                                   \
+                            create_function_single_value<ReaderFunctionData, FUNCTION_DATA,       \
+                                                         result_nullable_tag, arg_nullable_tag,   \
+                                                         is_copy>(name, argument_types,           \
+                                                                  parameters));                   \
+                },                                                                                \
+                make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable));       \
+        if (!res) {                                                                               \
+            LOG(WARNING) << " failed in create_aggregate_function_" << name                       \
+                         << " and type is: " << argument_types[0]->get_name();                    \
+        }                                                                                         \
+        return res;                                                                               \
+    }
+
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_first, ReaderFunctionFirstData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_first_non_null_value,
+                                          ReaderFunctionFirstNonNullData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_last, ReaderFunctionLastData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_last_non_null_value,
+                                          ReaderFunctionLastNonNullData);
+} // namespace doris::vectorized
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 73779f8ffa..79d21985bc 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -39,7 +39,8 @@ void 
register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory& f
 void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& 
factory);
 void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& 
factory);
-void 
register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& 
factory);
+void register_aggregate_function_window_lead_lag_first_last(
+        AggregateFunctionSimpleFactory& factory);
 void 
register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& 
factory);
 void 
register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory&
 factory);
 void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
@@ -81,7 +82,7 @@ AggregateFunctionSimpleFactory& 
AggregateFunctionSimpleFactory::instance() {
 
         register_aggregate_function_stddev_variance_samp(instance);
         register_aggregate_function_replace_reader_load(instance);
-        register_aggregate_function_window_lead_lag(instance);
+        register_aggregate_function_window_lead_lag_first_last(instance);
         register_aggregate_function_HLL_union_agg(instance);
         register_aggregate_function_percentile_approx(instance);
     });
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
index 1a342d805a..02b283ab2d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
@@ -22,7 +22,7 @@
 
 #include "common/logging.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/utils/template_helpers.hpp"
 
 namespace doris::vectorized {
 
@@ -62,25 +62,63 @@ AggregateFunctionPtr create_aggregate_function_ntile(const std::string& name,
     return std::make_shared<WindowFunctionNTile>(argument_types, parameters);
 }
 
-template <bool is_nullable>
-AggregateFunctionPtr create_aggregate_function_lag(const std::string& name,
-                                                   const DataTypes& argument_types,
-                                                   const Array& parameters,
-                                                   const bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, WindowFunctionLagData, is_nullable>(
-                    name, argument_types, parameters));
+// Creates AggregateFunctionTemplate<Impl<Data<ColumnType, result_is_nullable, arg_is_nullable>>>
+// for the column type of argument_types[0] (nullability removed); other types hit LOG(FATAL).
+template <template <typename> class AggregateFunctionTemplate,
+          template <typename ColVecType, bool, bool> class Data, template <typename> class Impl,
+          bool result_is_nullable, bool arg_is_nullable>
+static IAggregateFunction* create_function_lead_lag_first_last(const String& name,
+                                                               const DataTypes& argument_types,
+                                                               const Array& parameters) {
+    auto type = remove_nullable(argument_types[0]);
+    WhichDataType which(*type);
+
+#define DISPATCH(TYPE, COLUMN_TYPE)           \
+    if (which.idx == TypeIndex::TYPE)         \
+        return new AggregateFunctionTemplate< \
+                Impl<Data<COLUMN_TYPE, result_is_nullable, arg_is_nullable>>>(argument_types);
+    TYPE_TO_BASIC_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+
+    LOG(FATAL) << "with unknown type, failed in create_aggregate_function_" << name
+               << " and type is: " << argument_types[0]->get_name();
+    return nullptr;
 }
 
-template <bool is_nullable>
-AggregateFunctionPtr create_aggregate_function_lead(const std::string& name,
-                                                    const DataTypes& argument_types,
-                                                    const Array& parameters,
-                                                    const bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, WindowFunctionLeadData, is_nullable>(
-                    name, argument_types, parameters));
-}
+// Defines CREATE_FUNCTION_NAME(name, argument_types, parameters, result_is_nullable); the runtime
+// nullability flags become template arguments via make_bool_variant + std::visit.
+#define CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA,             \
+                                                  FUNCTION_IMPL)                                    \
+    AggregateFunctionPtr CREATE_FUNCTION_NAME(                                                     \
+            const std::string& name, const DataTypes& argument_types, const Array& parameters,     \
+            const bool result_is_nullable) {                                                       \
+        const bool arg_is_nullable = argument_types[0]->is_nullable();                             \
+        AggregateFunctionPtr res = nullptr;                                                        \
+                                                                                                   \
+        std::visit(                                                                                \
+                [&](auto result_nullable_tag, auto arg_nullable_tag) {                             \
+                    res = AggregateFunctionPtr(                                                    \
+                            create_function_lead_lag_first_last<WindowFunctionData, FUNCTION_DATA, \
+                                                                FUNCTION_IMPL, result_nullable_tag, \
+                                                                arg_nullable_tag>(                  \
+                                    name, argument_types, parameters));                            \
+                },                                                                                 \
+                make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable));        \
+        if (!res) {                                                                                \
+            LOG(WARNING) << " failed in create_aggregate_function_" << name                        \
+                         << " and type is: " << argument_types[0]->get_name();                     \
+        }                                                                                          \
+        return res;                                                                                \
+    }
+
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_lag, LeadLagData,
+                                          WindowFunctionLagImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_lead, LeadLagData,
+                                          WindowFunctionLeadImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_first, FirstLastData,
+                                          WindowFunctionFirstImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_last, FirstLastData,
+                                          WindowFunctionLastImpl);
 
 void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& 
factory) {
     factory.register_function("dense_rank", 
create_aggregate_function_dense_rank);
@@ -89,15 +123,16 @@ void 
register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& fac
     factory.register_function("ntile", create_aggregate_function_ntile);
 }
 
-void 
register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& 
factory) {
-    factory.register_function("lead", create_aggregate_function_lead<false>);
-    factory.register_function("lead", create_aggregate_function_lead<true>, 
true);
-    factory.register_function("lag", create_aggregate_function_lag<false>);
-    factory.register_function("lag", create_aggregate_function_lag<true>, 
true);
-    factory.register_function("first_value", 
create_aggregate_function_first<false, false>);
-    factory.register_function("first_value", 
create_aggregate_function_first<true, false>, true);
-    factory.register_function("last_value", 
create_aggregate_function_last<false, false>);
-    factory.register_function("last_value", 
create_aggregate_function_last<true, false>, true);
+void register_aggregate_function_window_lead_lag_first_last(
+        AggregateFunctionSimpleFactory& factory) { // each name: default + nullable (`true`) creator
+    factory.register_function("lead", create_aggregate_function_window_lead);
+    factory.register_function("lead", create_aggregate_function_window_lead, true);
+    factory.register_function("lag", create_aggregate_function_window_lag);
+    factory.register_function("lag", create_aggregate_function_window_lag, true);
+    factory.register_function("first_value", create_aggregate_function_window_first);
+    factory.register_function("first_value", create_aggregate_function_window_first, true);
+    factory.register_function("last_value", create_aggregate_function_window_last);
+    factory.register_function("last_value", create_aggregate_function_window_last, true);
 }
 
 } // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.h 
b/be/src/vec/aggregate_functions/aggregate_function_window.h
index 6e0dba239e..c0d37f8e80 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.h
@@ -22,6 +22,7 @@
 
 #include "factory_helpers.h"
 #include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_reader_first_last.h"
 #include "vec/aggregate_functions/helpers.h"
 #include "vec/columns/column_vector.h"
 #include "vec/data_types/data_type_decimal.h"
@@ -207,60 +208,36 @@ public:
     void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) 
const override {}
 };
 
-struct Value {
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable>
+struct FirstLastData
+        : public ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable, false> {
 public:
-    bool is_null() const { return _is_null; }
-    void set_null(bool is_null) { _is_null = is_null; }
-    StringRef get_value() const { return _ptr->get_data_at(_offset); }
-
-    void set_value(const IColumn* column, size_t row) {
-        _ptr = column;
-        _offset = row;
-    }
-    void reset() {
-        _is_null = false;
-        _ptr = nullptr;
-        _offset = 0;
-    }
-
-protected:
-    const IColumn* _ptr = nullptr;
-    size_t _offset = 0;
-    bool _is_null;
+    void set_is_null() { this->_data_value.reset(); } // a reset Value (null _ptr) reads as NULL
 };
 
-struct CopiedValue : public Value {
+template <typename ColVecType, bool arg_is_nullable>
+struct BaseValue : public Value<ColVecType, arg_is_nullable> {
 public:
-    StringRef get_value() const { return _copied_value; }
-
-    void set_value(const IColumn* column, size_t row) {
-        _copied_value = column->get_data_at(row).to_string();
-    }
-
-private:
-    std::string _copied_value;
+    bool is_null() const { return this->_ptr == nullptr; } // unlike Value, never inspects the row
+    // _ptr may point at columns[0] or at the default (third) argument column, so its concrete
+    // type is not fixed here; read the row through the virtual IColumn::get_data_at() instead.
+    StringRef get_value() const { return this->_ptr->get_data_at(this->_offset); }
 };
 
-template <typename T, bool result_is_nullable, bool is_string, typename 
StoreType = Value>
-struct LeadAndLagData {
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable>
+struct LeadLagData {
 public:
-    bool has_init() const { return _is_init; }
-
-    static constexpr bool nullable = result_is_nullable;
-
-    void set_null_if_need() {
-        if (!_has_value) {
-            this->set_is_null();
-        }
-    }
-
     void reset() {
         _data_value.reset();
         _default_value.reset();
-        _is_init = false;
-        _has_value = false;
+        _is_inited = false;
     }
 
+    bool default_is_null() { return _default_value.is_null(); }
+
+    // here _ptr pointer default column from third
+    void set_value_from_default() { this->_data_value = _default_value; }
+
     void insert_result_into(IColumn& to) const {
         if constexpr (result_is_nullable) {
             if (_data_value.is_null()) {
@@ -277,53 +254,47 @@ public:
         }
     }
 
+    void set_is_null() { this->_data_value.reset(); }
+
     void set_value(const IColumn** columns, size_t pos) {
-        if (columns[0]->is_nullable() &&
-            assert_cast<const ColumnNullable*>(columns[0])->is_null_at(pos)) {
-            _data_value.set_null(true);
-        } else {
-            _data_value.set_value(columns[0], pos);
-            _data_value.set_null(false);
+        if constexpr (arg_is_nullable) {
+            if (assert_cast<const 
ColumnNullable*>(columns[0])->is_null_at(pos)) {
+                // ptr == nullptr means nullable
+                _data_value.reset();
+                return;
+            }
         }
-        _has_value = true;
+        // here ptr is pointer to nullable column or not null column from first
+        _data_value.set_value(columns[0], pos);
     }
 
-    bool defualt_is_null() { return _default_value.is_null(); }
-
-    void set_is_null() { _data_value.set_null(true); }
-
-    void set_value_from_default() { _data_value = _default_value; }
-
-    bool has_set_value() { return _has_value; }
-
     void check_default(const IColumn* column) {
-        if (!has_init()) {
+        if (!_is_inited) {
             if (is_column_nullable(*column)) {
                 const auto* nullable_column = assert_cast<const 
ColumnNullable*>(column);
                 if (nullable_column->is_null_at(0)) {
-                    _default_value.set_null(true);
+                    _default_value.reset();
                 }
             } else {
                 _default_value.set_value(column, 0);
             }
-            _is_init = true;
+            _is_inited = true;
         }
     }
 
 private:
-    StoreType _data_value;
-    StoreType _default_value;
-    bool _has_value = false;
-    bool _is_init = false;
+    BaseValue<ColVecType, arg_is_nullable> _data_value;
+    BaseValue<ColVecType, arg_is_nullable> _default_value;
+    bool _is_inited = false;
 };
 
 template <typename Data>
-struct WindowFunctionLeadData : Data {
-    void add_range_single_place(int64_t partition_start, int64_t 
partition_end, size_t frame_start,
-                                size_t frame_end, const IColumn** columns) {
+struct WindowFunctionLeadImpl : Data {
+    void add_range_single_place(int64_t partition_start, int64_t 
partition_end, int64_t frame_start,
+                                int64_t frame_end, const IColumn** columns) {
         this->check_default(columns[2]);
         if (frame_end > partition_end) { //output default value, win end is 
under partition
-            if (this->defualt_is_null()) {
+            if (this->default_is_null()) {
                 this->set_is_null();
             } else {
                 this->set_value_from_default();
@@ -332,19 +303,17 @@ struct WindowFunctionLeadData : Data {
         }
         this->set_value(columns, frame_end - 1);
     }
-    void add(int64_t row, const IColumn** columns) {
-        LOG(FATAL) << "WindowFunctionLeadData do not support add";
-    }
+
     static const char* name() { return "lead"; }
 };
 
 template <typename Data>
-struct WindowFunctionLagData : Data {
+struct WindowFunctionLagImpl : Data {
     void add_range_single_place(int64_t partition_start, int64_t 
partition_end, int64_t frame_start,
                                 int64_t frame_end, const IColumn** columns) {
         this->check_default(columns[2]);
         if (partition_start >= frame_end) { //[unbound preceding(0), offset 
preceding(-123)]
-            if (this->defualt_is_null()) {  // win start is beyond partition
+            if (this->default_is_null()) {  // win start is beyond partition
                 this->set_is_null();
             } else {
                 this->set_value_from_default();
@@ -353,14 +322,15 @@ struct WindowFunctionLagData : Data {
         }
         this->set_value(columns, frame_end - 1);
     }
-    void add(int64_t row, const IColumn** columns) {
-        LOG(FATAL) << "WindowFunctionLagData do not support add";
-    }
+
     static const char* name() { return "lag"; }
 };
 
+// TODO: first_value && last_value in some corner case will be core,
+// if need to simply change it, should set them to always nullable insert into 
null value, and register in cpp maybe be change
+// But it's may be another better way to handle it
 template <typename Data>
-struct WindowFunctionFirstData : Data {
+struct WindowFunctionFirstImpl : Data {
     void add_range_single_place(int64_t partition_start, int64_t 
partition_end, int64_t frame_start,
                                 int64_t frame_end, const IColumn** columns) {
         if (this->has_set_value()) {
@@ -374,61 +344,12 @@ struct WindowFunctionFirstData : Data {
         frame_start = std::max<int64_t>(frame_start, partition_start);
         this->set_value(columns, frame_start);
     }
-    void add(int64_t row, const IColumn** columns) {
-        if (this->has_set_value()) {
-            return;
-        }
-        this->set_value(columns, row);
-    }
-    static const char* name() { return "first_value"; }
-};
-
-template <typename Data>
-struct WindowFunctionFirstNonNullData : Data {
-    void add_range_single_place(int64_t partition_start, int64_t 
partition_end, int64_t frame_start,
-                                int64_t frame_end, const IColumn** columns) {
-        if (this->has_set_value()) {
-            return;
-        }
-        if (frame_start < frame_end &&
-            frame_end <= partition_start) { //rewrite last_value when under 
partition
-            this->set_is_null();            //so no need more judge
-            return;
-        }
-        frame_start = std::max<int64_t>(frame_start, partition_start);
-        frame_end = std::min<int64_t>(frame_end, partition_end);
-        if constexpr (Data::nullable) {
-            this->set_null_if_need();
-            const auto* nullable_column = assert_cast<const 
ColumnNullable*>(columns[0]);
-            for (int i = frame_start; i < frame_end; i++) {
-                if (!nullable_column->is_null_at(i)) {
-                    this->set_value(columns, i);
-                    return;
-                }
-            }
-        } else {
-            this->set_value(columns, frame_start);
-        }
-    }
 
-    void add(int64_t row, const IColumn** columns) {
-        if (this->has_set_value()) {
-            return;
-        }
-        if constexpr (Data::nullable) {
-            this->set_null_if_need();
-            const auto* nullable_column = assert_cast<const 
ColumnNullable*>(columns[0]);
-            if (nullable_column->is_null_at(row)) {
-                return;
-            }
-        }
-        this->set_value(columns, row);
-    }
-    static const char* name() { return "first_non_null_value"; }
+    static const char* name() { return "first_value"; }
 };
 
 template <typename Data>
-struct WindowFunctionLastData : Data {
+struct WindowFunctionLastImpl : Data {
     void add_range_single_place(int64_t partition_start, int64_t 
partition_end, int64_t frame_start,
                                 int64_t frame_end, const IColumn** columns) {
         if ((frame_start < frame_end) &&
@@ -440,48 +361,8 @@ struct WindowFunctionLastData : Data {
         frame_end = std::min<int64_t>(frame_end, partition_end);
         this->set_value(columns, frame_end - 1);
     }
-    void add(int64_t row, const IColumn** columns) { this->set_value(columns, 
row); }
-    static const char* name() { return "last_value"; }
-};
-
-template <typename Data>
-struct WindowFunctionLastNonNullData : Data {
-    void add_range_single_place(int64_t partition_start, int64_t 
partition_end, int64_t frame_start,
-                                int64_t frame_end, const IColumn** columns) {
-        if ((frame_start < frame_end) &&
-            ((frame_end <= partition_start) ||
-             (frame_start >= partition_end))) { //beyond or under partition, 
set null
-            this->set_is_null();
-            return;
-        }
-        frame_start = std::max<int64_t>(frame_start, partition_start);
-        frame_end = std::min<int64_t>(frame_end, partition_end);
-        if constexpr (Data::nullable) {
-            this->set_null_if_need();
-            const auto* nullable_column = assert_cast<const 
ColumnNullable*>(columns[0]);
-            for (int i = frame_end - 1; i >= frame_start; i--) {
-                if (!nullable_column->is_null_at(i)) {
-                    this->set_value(columns, i);
-                    return;
-                }
-            }
-        } else {
-            this->set_value(columns, frame_end - 1);
-        }
-    }
-
-    void add(int64_t row, const IColumn** columns) {
-        if constexpr (Data::nullable) {
-            this->set_null_if_need();
-            const auto* nullable_column = assert_cast<const 
ColumnNullable*>(columns[0]);
-            if (nullable_column->is_null_at(row)) {
-                return;
-            }
-        }
-        this->set_value(columns, row);
-    }
 
-    static const char* name() { return "last_non_null_value"; }
+    static const char* name() { return "last_value"; }
 };
 
 template <typename Data>
@@ -493,6 +374,7 @@ public:
               _argument_type(argument_types[0]) {}
 
     String get_name() const override { return Data::name(); }
+
     DataTypePtr get_return_type() const override { return _argument_type; }
 
     void add_range_single_place(int64_t partition_start, int64_t 
partition_end, int64_t frame_start,
@@ -510,104 +392,20 @@ public:
 
     void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
              Arena* arena) const override {
-        this->data(place).add(row_num, columns);
+        LOG(FATAL) << "WindowFunctionLeadLagData do not support add";
     }
     void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*) 
const override {
-        LOG(FATAL) << "WindowFunctionData do not support merge";
+        LOG(FATAL) << "WindowFunctionLeadLagData do not support merge";
     }
     void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const 
override {
-        LOG(FATAL) << "WindowFunctionData do not support serialize";
+        LOG(FATAL) << "WindowFunctionLeadLagData do not support serialize";
     }
     void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) 
const override {
-        LOG(FATAL) << "WindowFunctionData do not support deserialize";
+        LOG(FATAL) << "WindowFunctionLeadLagData do not support deserialize";
     }
 
 private:
     DataTypePtr _argument_type;
 };
 
-template <template <typename> class AggregateFunctionTemplate, template 
<typename> class Data,
-          bool result_is_nullable, bool is_copy = false>
-static IAggregateFunction* create_function_single_value(const String& name,
-                                                        const DataTypes& 
argument_types,
-                                                        const Array& 
parameters) {
-    using StoreType = std::conditional_t<is_copy, CopiedValue, Value>;
-
-    assert_arity_at_most<3>(name, argument_types);
-
-    auto type = remove_nullable(argument_types[0]);
-    WhichDataType which(*type);
-
-#define DISPATCH(TYPE)                        \
-    if (which.idx == TypeIndex::TYPE)         \
-        return new AggregateFunctionTemplate< \
-                Data<LeadAndLagData<TYPE, result_is_nullable, false, 
StoreType>>>(argument_types);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-
-    if (which.is_decimal()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<Int128, result_is_nullable, false, 
StoreType>>>(argument_types);
-    }
-    if (which.is_date_or_datetime()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<Int64, result_is_nullable, false, 
StoreType>>>(argument_types);
-    }
-    if (which.is_date_v2()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<UInt32, result_is_nullable, false, 
StoreType>>>(argument_types);
-    }
-    if (which.is_date_time_v2()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<UInt64, result_is_nullable, false, 
StoreType>>>(argument_types);
-    }
-    if (which.is_string_or_fixed_string()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<StringRef, result_is_nullable, true, 
StoreType>>>(
-                argument_types);
-    }
-    DCHECK(false) << "with unknowed type, failed in  
create_aggregate_function_" << name;
-    return nullptr;
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_first(const std::string& name,
-                                                     const DataTypes& 
argument_types,
-                                                     const Array& parameters,
-                                                     bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, 
WindowFunctionFirstData, is_nullable,
-                                         is_copy>(name, argument_types, 
parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_first_non_null_value(const 
std::string& name,
-                                                                    const 
DataTypes& argument_types,
-                                                                    const 
Array& parameters,
-                                                                    bool 
result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, 
WindowFunctionFirstNonNullData,
-                                         is_nullable, is_copy>(name, 
argument_types, parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_last(const std::string& name,
-                                                    const DataTypes& 
argument_types,
-                                                    const Array& parameters,
-                                                    bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, 
WindowFunctionLastData, is_nullable,
-                                         is_copy>(name, argument_types, 
parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_last_non_null_value(const 
std::string& name,
-                                                                   const 
DataTypes& argument_types,
-                                                                   const 
Array& parameters,
-                                                                   bool 
result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, 
WindowFunctionLastNonNullData,
-                                         is_nullable, is_copy>(name, 
argument_types, parameters));
-}
-
 } // namespace doris::vectorized
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp 
b/be/src/vec/exec/join/vhash_join_node.cpp
index c88fbbb683..6cbdcfa53f 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -31,14 +31,6 @@
 
 namespace doris::vectorized {
 
-std::variant<std::false_type, std::true_type> static inline 
make_bool_variant(bool condition) {
-    if (condition) {
-        return std::true_type {};
-    } else {
-        return std::false_type {};
-    }
-}
-
 using ProfileCounter = RuntimeProfile::Counter;
 template <class HashTableContext>
 struct ProcessHashTableBuild {
diff --git a/be/src/vec/utils/template_helpers.hpp 
b/be/src/vec/utils/template_helpers.hpp
index ebf822513b..187ec7accc 100644
--- a/be/src/vec/utils/template_helpers.hpp
+++ b/be/src/vec/utils/template_helpers.hpp
@@ -18,6 +18,7 @@
 #pragma once
 
 #include <limits>
+#include <variant>
 
 #include "http/http_status.h"
 #include "vec/aggregate_functions/aggregate_function.h"
@@ -53,11 +54,14 @@
     M(BitMap, ColumnBitmap)            \
     M(HLL, ColumnHLL)
 
-#define TYPE_TO_COLUMN_TYPE(M)     \
-    NUMERIC_TYPE_TO_COLUMN_TYPE(M) \
-    DECIMAL_TYPE_TO_COLUMN_TYPE(M) \
-    STRING_TYPE_TO_COLUMN_TYPE(M)  \
-    TIME_TYPE_TO_COLUMN_TYPE(M)    \
+#define TYPE_TO_BASIC_COLUMN_TYPE(M) \
+    NUMERIC_TYPE_TO_COLUMN_TYPE(M)   \
+    DECIMAL_TYPE_TO_COLUMN_TYPE(M)   \
+    STRING_TYPE_TO_COLUMN_TYPE(M)    \
+    TIME_TYPE_TO_COLUMN_TYPE(M)
+
+#define TYPE_TO_COLUMN_TYPE(M)   \
+    TYPE_TO_BASIC_COLUMN_TYPE(M) \
     COMPLEX_TYPE_TO_COLUMN_TYPE(M)
 
 namespace doris::vectorized {
@@ -150,4 +154,12 @@ template <template <bool, bool, bool> typename Reducer>
 using constexpr_3_bool_match =
         constexpr_3_loop_match<bool, false, true, Reducer, 
constexpr_2_bool_match>;
 
+std::variant<std::false_type, std::true_type> static inline 
make_bool_variant(bool condition) {
+    if (condition) {
+        return std::true_type {};
+    } else {
+        return std::false_type {};
+    }
+}
+
 } // namespace  doris::vectorized


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to