This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 2b7dda7fa4994fc6251e7b54534e87594d93fc78
Author: Mryange <[email protected]>
AuthorDate: Tue Apr 23 09:18:34 2024 +0800

    [featrue](pipelineX) check output type in some node (#33716)
---
 be/src/pipeline/exec/aggregation_sink_operator.cpp |  9 +++++++--
 be/src/pipeline/exec/aggregation_sink_operator.h   |  2 ++
 be/src/pipeline/exec/hashjoin_build_sink.cpp       | 22 +++++++++++++++++++---
 .../exec/streaming_aggregation_operator.cpp        |  9 +++++++--
 .../pipeline/exec/streaming_aggregation_operator.h |  1 +
 be/src/pipeline/exec/union_sink_operator.cpp       |  1 +
 be/src/vec/exprs/vectorized_agg_fn.cpp             | 19 +++++++++++++++++++
 be/src/vec/exprs/vectorized_agg_fn.h               |  4 ++++
 8 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp b/be/src/pipeline/exec/aggregation_sink_operator.cpp
index e29d6de2860..869685f02e5 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.cpp
+++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp
@@ -631,7 +631,8 @@ AggSinkOperatorX::AggSinkOperatorX(ObjectPool* pool, int operator_id, const TPla
                           (tnode.__isset.conjuncts && !tnode.conjuncts.empty())),
           _partition_exprs(tnode.__isset.distribute_expr_lists ? tnode.distribute_expr_lists[0]
                                                                : std::vector<TExpr> {}),
-          _is_colocate(tnode.agg_node.__isset.is_colocate && tnode.agg_node.is_colocate) {}
+          _is_colocate(tnode.agg_node.__isset.is_colocate && tnode.agg_node.is_colocate),
+          _agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {}
 
 Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
     RETURN_IF_ERROR(DataSinkOperatorX<AggSinkLocalState>::init(tnode, state));
@@ -714,7 +715,11 @@ Status AggSinkOperatorX::prepare(RuntimeState* state) {
                     alignment_of_next_state * alignment_of_next_state;
         }
     }
-
+    // check output type
+    if (_needs_finalize) {
+        RETURN_IF_ERROR(vectorized::AggFnEvaluator::check_agg_fn_output(
+                _probe_expr_ctxs.size(), _aggregate_evaluators, _agg_fn_output_row_descriptor));
+    }
     return Status::OK();
 }
 
diff --git a/be/src/pipeline/exec/aggregation_sink_operator.h b/be/src/pipeline/exec/aggregation_sink_operator.h
index e3d8baad39c..b3ffa19d6db 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.h
+++ b/be/src/pipeline/exec/aggregation_sink_operator.h
@@ -213,6 +213,8 @@ protected:
 
     const std::vector<TExpr> _partition_exprs;
     const bool _is_colocate;
+
+    RowDescriptor _agg_fn_output_row_descriptor;
 };
 
 } // namespace pipeline
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index 176eaf33b1d..a0d111c63a7 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -22,6 +22,7 @@
 #include "exprs/bloom_filter_func.h"
 #include "pipeline/exec/hashjoin_probe_operator.h"
 #include "pipeline/exec/operator.h"
+#include "vec/data_types/data_type_nullable.h"
 #include "vec/exec/join/vhash_join_node.h"
 #include "vec/utils/template_helpers.hpp"
 
@@ -461,9 +462,24 @@ Status HashJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* st
 
     const std::vector<TEqJoinCondition>& eq_join_conjuncts = tnode.hash_join_node.eq_join_conjuncts;
     for (const auto& eq_join_conjunct : eq_join_conjuncts) {
-        vectorized::VExprContextSPtr ctx;
-        RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, ctx));
-        _build_expr_ctxs.push_back(ctx);
+        vectorized::VExprContextSPtr build_ctx;
+        RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, build_ctx));
+        {
+            // for type check
+            vectorized::VExprContextSPtr probe_ctx;
+            RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.left, probe_ctx));
+            auto build_side_expr_type = build_ctx->root()->data_type();
+            auto probe_side_expr_type = probe_ctx->root()->data_type();
+            if (!vectorized::make_nullable(build_side_expr_type)
+                         ->equals(*vectorized::make_nullable(probe_side_expr_type))) {
+                return Status::InternalError(
+                        "build side type {}, not match probe side type {} , node info "
+                        "{}",
+                        build_side_expr_type->get_name(), probe_side_expr_type->get_name(),
+                        this->debug_string(0));
+            }
+        }
+        _build_expr_ctxs.push_back(build_ctx);
 
         const auto vexpr = _build_expr_ctxs.back()->root();
 
diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.cpp b/be/src/pipeline/exec/streaming_aggregation_operator.cpp
index dfcfb0ebc45..f33d799db44 100644
--- a/be/src/pipeline/exec/streaming_aggregation_operator.cpp
+++ b/be/src/pipeline/exec/streaming_aggregation_operator.cpp
@@ -1145,7 +1145,8 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id,
           _needs_finalize(tnode.agg_node.need_finalize),
           _is_merge(false),
           _is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase),
-          _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()) {}
+          _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()),
+          _agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {}
 
 Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
     RETURN_IF_ERROR(StatefulOperatorX<StreamingAggLocalState>::init(tnode, state));
@@ -1235,7 +1236,11 @@ Status StreamingAggOperatorX::prepare(RuntimeState* state) {
                     alignment_of_next_state * alignment_of_next_state;
         }
     }
-
+    // check output type
+    if (_needs_finalize) {
+        RETURN_IF_ERROR(vectorized::AggFnEvaluator::check_agg_fn_output(
+                _probe_expr_ctxs.size(), _aggregate_evaluators, _agg_fn_output_row_descriptor));
+    }
     return Status::OK();
 }
 
diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.h b/be/src/pipeline/exec/streaming_aggregation_operator.h
index 2895fc63f39..caaee88b3c5 100644
--- a/be/src/pipeline/exec/streaming_aggregation_operator.h
+++ b/be/src/pipeline/exec/streaming_aggregation_operator.h
@@ -243,6 +243,7 @@ private:
     bool _can_short_circuit = false;
     std::vector<size_t> _make_nullable_keys;
     bool _have_conjuncts;
+    RowDescriptor _agg_fn_output_row_descriptor;
 };
 
 } // namespace pipeline
diff --git a/be/src/pipeline/exec/union_sink_operator.cpp b/be/src/pipeline/exec/union_sink_operator.cpp
index 5acf6c8e1a2..e466237a375 100644
--- a/be/src/pipeline/exec/union_sink_operator.cpp
+++ b/be/src/pipeline/exec/union_sink_operator.cpp
@@ -138,6 +138,7 @@ Status UnionSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
 
 Status UnionSinkOperatorX::prepare(RuntimeState* state) {
     RETURN_IF_ERROR(vectorized::VExpr::prepare(_child_expr, state, _child_x->row_desc()));
+    RETURN_IF_ERROR(vectorized::VExpr::check_expr_output_type(_child_expr, _row_descriptor));
     return Status::OK();
 }
 
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 4dfdff78205..d0fbf363727 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -350,4 +350,23 @@ AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state)
     }
 }
 
+Status AggFnEvaluator::check_agg_fn_output(int key_size,
+                                           const std::vector<vectorized::AggFnEvaluator*>& agg_fn,
+                                           const RowDescriptor& output_row_desc) {
+    auto name_and_types = VectorizedUtils::create_name_and_data_types(output_row_desc);
+    for (int i = key_size, j = 0; i < name_and_types.size(); i++, j++) {
+        auto&& [name, column_type] = name_and_types[i];
+        auto agg_return_type = agg_fn[j]->function()->get_return_type();
+        if (!column_type->equals(*agg_return_type)) {
+            if (!column_type->is_nullable() || agg_return_type->is_nullable() ||
+                !remove_nullable(column_type)->equals(*agg_return_type)) {
+                return Status::InternalError(
+                        "column_type not match data_types in agg node, column_type={}, "
+                        "data_types={},column name={}",
+                        column_type->get_name(), agg_return_type->get_name(), name);
+            }
+        }
+    }
+    return Status::OK();
+}
 } // namespace doris::vectorized
diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h
index 546b939ddf4..7dcd1b3e02b 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.h
+++ b/be/src/vec/exprs/vectorized_agg_fn.h
@@ -97,6 +97,10 @@ public:
     bool is_merge() const { return _is_merge; }
     const VExprContextSPtrs& input_exprs_ctxs() const { return _input_exprs_ctxs; }
 
+    static Status check_agg_fn_output(int key_size,
+                                      const std::vector<vectorized::AggFnEvaluator*>& agg_fn,
+                                      const RowDescriptor& output_row_desc);
+                                      const RowDescriptor& output_row_desc);
+
     void set_version(const int version) { _function->set_version(version); }
 
     AggFnEvaluator* clone(RuntimeState* state, ObjectPool* pool);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to