This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 2b7dda7fa4994fc6251e7b54534e87594d93fc78 Author: Mryange <[email protected]> AuthorDate: Tue Apr 23 09:18:34 2024 +0800 [featrue](pipelineX) check output type in some node (#33716) --- be/src/pipeline/exec/aggregation_sink_operator.cpp | 9 +++++++-- be/src/pipeline/exec/aggregation_sink_operator.h | 2 ++ be/src/pipeline/exec/hashjoin_build_sink.cpp | 22 +++++++++++++++++++--- .../exec/streaming_aggregation_operator.cpp | 9 +++++++-- .../pipeline/exec/streaming_aggregation_operator.h | 1 + be/src/pipeline/exec/union_sink_operator.cpp | 1 + be/src/vec/exprs/vectorized_agg_fn.cpp | 19 +++++++++++++++++++ be/src/vec/exprs/vectorized_agg_fn.h | 4 ++++ 8 files changed, 60 insertions(+), 7 deletions(-) diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp b/be/src/pipeline/exec/aggregation_sink_operator.cpp index e29d6de2860..869685f02e5 100644 --- a/be/src/pipeline/exec/aggregation_sink_operator.cpp +++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp @@ -631,7 +631,8 @@ AggSinkOperatorX::AggSinkOperatorX(ObjectPool* pool, int operator_id, const TPla (tnode.__isset.conjuncts && !tnode.conjuncts.empty())), _partition_exprs(tnode.__isset.distribute_expr_lists ? tnode.distribute_expr_lists[0] : std::vector<TExpr> {}), - _is_colocate(tnode.agg_node.__isset.is_colocate && tnode.agg_node.is_colocate) {} + _is_colocate(tnode.agg_node.__isset.is_colocate && tnode.agg_node.is_colocate), + _agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {} Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR(DataSinkOperatorX<AggSinkLocalState>::init(tnode, state)); @@ -714,7 +715,11 @@ Status AggSinkOperatorX::prepare(RuntimeState* state) { alignment_of_next_state * alignment_of_next_state; } } - + // check output type + if (_needs_finalize) { + RETURN_IF_ERROR(vectorized::AggFnEvaluator::check_agg_fn_output( + _probe_expr_ctxs.size(), _aggregate_evaluators, _agg_fn_output_row_descriptor)); + } return Status::OK(); } diff --git a/be/src/pipeline/exec/aggregation_sink_operator.h b/be/src/pipeline/exec/aggregation_sink_operator.h index e3d8baad39c..b3ffa19d6db 100644 --- a/be/src/pipeline/exec/aggregation_sink_operator.h +++ b/be/src/pipeline/exec/aggregation_sink_operator.h @@ -213,6 +213,8 @@ protected: const std::vector<TExpr> _partition_exprs; const bool _is_colocate; + + RowDescriptor _agg_fn_output_row_descriptor; }; } // namespace pipeline diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp b/be/src/pipeline/exec/hashjoin_build_sink.cpp index 176eaf33b1d..a0d111c63a7 100644 --- a/be/src/pipeline/exec/hashjoin_build_sink.cpp +++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp @@ -22,6 +22,7 @@ #include "exprs/bloom_filter_func.h" #include "pipeline/exec/hashjoin_probe_operator.h" #include "pipeline/exec/operator.h" +#include "vec/data_types/data_type_nullable.h" #include "vec/exec/join/vhash_join_node.h" #include "vec/utils/template_helpers.hpp" @@ -461,9 +462,24 @@ Status HashJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* st const std::vector<TEqJoinCondition>& eq_join_conjuncts = tnode.hash_join_node.eq_join_conjuncts; for (const auto& eq_join_conjunct : eq_join_conjuncts) { - vectorized::VExprContextSPtr ctx; - RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, ctx)); - _build_expr_ctxs.push_back(ctx); + vectorized::VExprContextSPtr build_ctx; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, build_ctx)); + { + // for type check + vectorized::VExprContextSPtr probe_ctx; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.left, probe_ctx)); + auto build_side_expr_type = build_ctx->root()->data_type(); + auto probe_side_expr_type = probe_ctx->root()->data_type(); + if (!vectorized::make_nullable(build_side_expr_type) + ->equals(*vectorized::make_nullable(probe_side_expr_type))) { + return Status::InternalError( + "build side type {}, not match probe side type {} , node info " + "{}", + build_side_expr_type->get_name(), probe_side_expr_type->get_name(), + this->debug_string(0)); + } + } + _build_expr_ctxs.push_back(build_ctx); const auto vexpr = _build_expr_ctxs.back()->root(); diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.cpp b/be/src/pipeline/exec/streaming_aggregation_operator.cpp index dfcfb0ebc45..f33d799db44 100644 --- a/be/src/pipeline/exec/streaming_aggregation_operator.cpp +++ b/be/src/pipeline/exec/streaming_aggregation_operator.cpp @@ -1145,7 +1145,8 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id, _needs_finalize(tnode.agg_node.need_finalize), _is_merge(false), _is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase), - _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()) {} + _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()), + _agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {} Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR(StatefulOperatorX<StreamingAggLocalState>::init(tnode, state)); @@ -1235,7 +1236,11 @@ Status StreamingAggOperatorX::prepare(RuntimeState* state) { alignment_of_next_state * alignment_of_next_state; } } - + // check output type + if (_needs_finalize) { + RETURN_IF_ERROR(vectorized::AggFnEvaluator::check_agg_fn_output( + _probe_expr_ctxs.size(), _aggregate_evaluators, _agg_fn_output_row_descriptor)); + } return Status::OK(); } diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.h b/be/src/pipeline/exec/streaming_aggregation_operator.h index 2895fc63f39..caaee88b3c5 100644 --- a/be/src/pipeline/exec/streaming_aggregation_operator.h +++ b/be/src/pipeline/exec/streaming_aggregation_operator.h @@ -243,6 +243,7 @@ private: bool _can_short_circuit = false; std::vector<size_t> _make_nullable_keys; bool _have_conjuncts; + RowDescriptor _agg_fn_output_row_descriptor; }; } // namespace pipeline diff --git a/be/src/pipeline/exec/union_sink_operator.cpp b/be/src/pipeline/exec/union_sink_operator.cpp index 5acf6c8e1a2..e466237a375 100644 --- a/be/src/pipeline/exec/union_sink_operator.cpp +++ b/be/src/pipeline/exec/union_sink_operator.cpp @@ -138,6 +138,7 @@ Status UnionSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { Status UnionSinkOperatorX::prepare(RuntimeState* state) { RETURN_IF_ERROR(vectorized::VExpr::prepare(_child_expr, state, _child_x->row_desc())); + RETURN_IF_ERROR(vectorized::VExpr::check_expr_output_type(_child_expr, _row_descriptor)); return Status::OK(); } diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index 4dfdff78205..d0fbf363727 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -350,4 +350,23 @@ AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state) } } +Status AggFnEvaluator::check_agg_fn_output(int key_size, + const std::vector<vectorized::AggFnEvaluator*>& agg_fn, + const RowDescriptor& output_row_desc) { + auto name_and_types = VectorizedUtils::create_name_and_data_types(output_row_desc); + for (int i = key_size, j = 0; i < name_and_types.size(); i++, j++) { + auto&& [name, column_type] = name_and_types[i]; + auto agg_return_type = agg_fn[j]->function()->get_return_type(); + if (!column_type->equals(*agg_return_type)) { + if (!column_type->is_nullable() || agg_return_type->is_nullable() || + !remove_nullable(column_type)->equals(*agg_return_type)) { + return Status::InternalError( + "column_type not match data_types in agg node, column_type={}, " + "data_types={},column name={}", + column_type->get_name(), agg_return_type->get_name(), name); + } + } + } + return Status::OK(); +} } // namespace doris::vectorized diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h index 546b939ddf4..7dcd1b3e02b 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.h +++ b/be/src/vec/exprs/vectorized_agg_fn.h @@ -97,6 +97,10 @@ public: bool is_merge() const { return _is_merge; } const VExprContextSPtrs& input_exprs_ctxs() const { return _input_exprs_ctxs; } + static Status check_agg_fn_output(int key_size, + const std::vector<vectorized::AggFnEvaluator*>& agg_fn, + const RowDescriptor& output_row_desc); + void set_version(const int version) { _function->set_version(version); } AggFnEvaluator* clone(RuntimeState* state, ObjectPool* pool); --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
