This is an automated email from the ASF dual-hosted git repository.

gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 21aea76333 [pipelineX](feature) support assert rows num operator (#23857)
21aea76333 is described below

commit 21aea76333f3e818c667e622a1d06fe3effa7785
Author: Gabriel <[email protected]>
AuthorDate: Mon Sep 4 18:17:26 2023 +0800

    [pipelineX](feature) support assert rows num operator (#23857)
---
 be/src/pipeline/exec/assert_num_rows_operator.cpp  | 87 ++++++++++++++++++++++
 be/src/pipeline/exec/assert_num_rows_operator.h    | 30 +++++++-
 be/src/pipeline/exec/hashjoin_probe_operator.cpp   | 65 +++++-----------
 be/src/pipeline/exec/hashjoin_probe_operator.h     | 13 ++--
 be/src/pipeline/exec/join_probe_operator.h         |  6 +-
 .../exec/nested_loop_join_probe_operator.cpp       | 28 -------
 .../exec/nested_loop_join_probe_operator.h         |  9 +--
 be/src/pipeline/exec/repeat_operator.cpp           | 74 +++++++-----------
 be/src/pipeline/exec/repeat_operator.h             | 22 +++---
 be/src/pipeline/pipeline_x/operator.cpp            | 50 ++++++++++++-
 be/src/pipeline/pipeline_x/operator.h              | 49 ++++++++++--
 .../pipeline_x/pipeline_x_fragment_context.cpp     |  6 ++
 12 files changed, 279 insertions(+), 160 deletions(-)

diff --git a/be/src/pipeline/exec/assert_num_rows_operator.cpp 
b/be/src/pipeline/exec/assert_num_rows_operator.cpp
new file mode 100644
index 0000000000..315ecd3e51
--- /dev/null
+++ b/be/src/pipeline/exec/assert_num_rows_operator.cpp
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "assert_num_rows_operator.h"
+
+namespace doris::pipeline {
+
+OperatorPtr AssertNumRowsOperatorBuilder::build_operator() {
+    return std::make_shared<AssertNumRowsOperator>(this, _node);
+}
+
+AssertNumRowsOperatorX::AssertNumRowsOperatorX(ObjectPool* pool, const TPlanNode& tnode,
+                                               const DescriptorTbl& descs)
+        : StreamingOperatorX<AssertNumRowsLocalState>(pool, tnode, descs),
+          _desired_num_rows(tnode.assert_num_rows_node.desired_num_rows),
+          _subquery_string(tnode.assert_num_rows_node.subquery_string) {
+    // Plans that leave `assertion` unset keep the previous behaviour: at most N rows (LE).
+    const auto& node = tnode.assert_num_rows_node;
+    _assertion = node.__isset.assertion ? node.assertion : TAssertion::LE;
+}
+
+// Adds the rows of each child block to the running count and fails the query with
+// Status::Cancelled once `num_rows_returned <_assertion> _desired_num_rows` cannot hold.
+// Conditions that more input could still satisfy (EQ/NE/GT/GE) are only enforced once the
+// child reports SourceState::FINISHED, so an empty or partial intermediate block does not
+// cancel the query prematurely; LT/LE (and EQ once exceeded) fail as soon as violated.
+Status AssertNumRowsOperatorX::pull(doris::RuntimeState* state, vectorized::Block* block,
+                                    SourceState& source_state) {
+    auto& local_state = state->get_local_state(id())->cast<AssertNumRowsLocalState>();
+    local_state.add_num_rows_returned(block->rows());
+    const int64_t num_rows_returned = local_state.num_rows_returned();
+    const bool has_more_rows = source_state != SourceState::FINISHED;
+    bool assert_res = false;
+    switch (_assertion) {
+    case TAssertion::EQ:
+        assert_res = num_rows_returned == _desired_num_rows ||
+                     (has_more_rows && num_rows_returned < _desired_num_rows);
+        break;
+    case TAssertion::NE:
+        assert_res = num_rows_returned != _desired_num_rows || has_more_rows;
+        break;
+    case TAssertion::LT:
+        assert_res = num_rows_returned < _desired_num_rows;
+        break;
+    case TAssertion::LE:
+        assert_res = num_rows_returned <= _desired_num_rows;
+        break;
+    case TAssertion::GT:
+        assert_res = num_rows_returned > _desired_num_rows || has_more_rows;
+        break;
+    case TAssertion::GE:
+        assert_res = num_rows_returned >= _desired_num_rows || has_more_rows;
+        break;
+    default:
+        break;
+    }
+
+    if (!assert_res) {
+        // Compare against end() of the same map that was searched; mixing containers is UB.
+        // Unknown enum values are reported as "NULL".
+        const auto it = _TAssertion_VALUES_TO_NAMES.find(_assertion);
+        const char* assertion_name =
+                it == _TAssertion_VALUES_TO_NAMES.end() ? "NULL" : it->second;
+        LOG(INFO) << "Expected " << assertion_name << " " << _desired_num_rows
+                  << " to be returned by expression " << _subquery_string;
+        return Status::Cancelled("Expected {} {} to be returned by expression {}",
+                                 assertion_name, _desired_num_rows, _subquery_string);
+    }
+    COUNTER_SET(local_state.rows_returned_counter(), local_state.num_rows_returned());
+    return Status::OK();
+}
+
+} // namespace doris::pipeline
diff --git a/be/src/pipeline/exec/assert_num_rows_operator.h 
b/be/src/pipeline/exec/assert_num_rows_operator.h
index 7f1371e521..2c271be3d6 100644
--- a/be/src/pipeline/exec/assert_num_rows_operator.h
+++ b/be/src/pipeline/exec/assert_num_rows_operator.h
@@ -18,6 +18,7 @@
 #pragma once
 
 #include "operator.h"
+#include "pipeline/pipeline_x/operator.h"
 #include "vec/exec/vassert_num_rows_node.h"
 
 namespace doris {
@@ -38,9 +39,30 @@ public:
             : StreamingOperator(operator_builder, node) {}
 };
 
-OperatorPtr AssertNumRowsOperatorBuilder::build_operator() {
-    return std::make_shared<AssertNumRowsOperator>(this, _node);
-}
+class AssertNumRowsLocalState final : public PipelineXLocalState<FakeDependency> {
+public:
+    ENABLE_FACTORY_CREATOR(AssertNumRowsLocalState);
+    // No own members: the running row count uses the inherited num_rows_returned().
+    AssertNumRowsLocalState(RuntimeState* state, OperatorXBase* parent)
+            : PipelineXLocalState<FakeDependency>(state, parent) {}
+    ~AssertNumRowsLocalState() = default;
+};
+// Streaming operator asserting the child's row count against an expected value.
+class AssertNumRowsOperatorX final : public StreamingOperatorX<AssertNumRowsLocalState> {
+public:
+    AssertNumRowsOperatorX(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs);
+    // Adds block->rows() to the running count; returns Status::Cancelled on a failed check.
+    Status pull(RuntimeState* state, vectorized::Block* block, SourceState& source_state) override;
+    // Reports false: this operator is not a pipeline source.
+    [[nodiscard]] bool is_source() const override { return false; }
+
+private:
+    friend class AssertNumRowsLocalState;
+    // Initialized from tnode.assert_num_rows_node in the constructor (LE if unset).
+    int64_t _desired_num_rows;
+    const std::string _subquery_string;
+    TAssertion::type _assertion;
+};
 
 } // namespace pipeline
-} // namespace doris
\ No newline at end of file
+} // namespace doris
diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.cpp 
b/be/src/pipeline/exec/hashjoin_probe_operator.cpp
index dd9d2451de..88bbcfc631 100644
--- a/be/src/pipeline/exec/hashjoin_probe_operator.cpp
+++ b/be/src/pipeline/exec/hashjoin_probe_operator.cpp
@@ -75,18 +75,21 @@ Status HashJoinProbeLocalState::close(RuntimeState* state) {
     if (_closed) {
         return Status::OK();
     }
-    std::visit(vectorized::Overload {[&](std::monostate&) {},
-                                     [&](auto&& process_hashtable_ctx) {
-                                         if (process_hashtable_ctx._arena) {
-                                             
process_hashtable_ctx._arena.reset();
-                                         }
-
-                                         if 
(process_hashtable_ctx._serialize_key_arena) {
-                                             
process_hashtable_ctx._serialize_key_arena.reset();
-                                             
process_hashtable_ctx._serialized_key_buffer_size = 0;
-                                         }
-                                     }},
-               *_process_hashtable_ctx_variants);
+    if (_process_hashtable_ctx_variants) {
+        std::visit(vectorized::Overload {[&](std::monostate&) {},
+                                         [&](auto&& process_hashtable_ctx) {
+                                             if (process_hashtable_ctx._arena) 
{
+                                                 
process_hashtable_ctx._arena.reset();
+                                             }
+
+                                             if 
(process_hashtable_ctx._serialize_key_arena) {
+                                                 
process_hashtable_ctx._serialize_key_arena.reset();
+                                                 
process_hashtable_ctx._serialized_key_buffer_size =
+                                                         0;
+                                             }
+                                         }},
+                   *_process_hashtable_ctx_variants);
+    }
     _shared_state->arena = nullptr;
     _shared_state->hash_table_variants.reset();
     _process_hashtable_ctx_variants = nullptr;
@@ -180,39 +183,10 @@ 
HashJoinProbeOperatorX::HashJoinProbeOperatorX(ObjectPool* pool, const TPlanNode
                                         ? 
tnode.hash_join_node.hash_output_slot_ids
                                         : std::vector<SlotId> {}) {}
 
-Status HashJoinProbeOperatorX::get_block(RuntimeState* state, 
vectorized::Block* block,
-                                         SourceState& source_state) {
-    auto& local_state = 
state->get_local_state(id())->cast<HashJoinProbeLocalState>();
-    local_state.init_for_probe(state);
-    if (need_more_input_data(state)) {
-        local_state._child_block->clear_column_data();
-        RETURN_IF_ERROR(_child_x->get_next_after_projects(state, 
local_state._child_block.get(),
-                                                          
local_state._child_source_state));
-        source_state = local_state._child_source_state;
-        if (local_state._child_block->rows() == 0 &&
-            local_state._child_source_state != SourceState::FINISHED) {
-            return Status::OK();
-        }
-        local_state.prepare_for_next();
-        RETURN_IF_ERROR(
-                push(state, local_state._child_block.get(), 
local_state._child_source_state));
-    }
-
-    if (!need_more_input_data(state)) {
-        RETURN_IF_ERROR(pull(state, block, source_state));
-        if (source_state != SourceState::FINISHED && 
!need_more_input_data(state)) {
-            source_state = SourceState::MORE_DATA;
-        } else if (source_state != SourceState::FINISHED &&
-                   source_state == SourceState::MORE_DATA) {
-            source_state = local_state._child_source_state;
-        }
-    }
-    return Status::OK();
-}
-
 Status HashJoinProbeOperatorX::pull(doris::RuntimeState* state, 
vectorized::Block* output_block,
-                                    SourceState& source_state) {
+                                    SourceState& source_state) const {
     auto& local_state = 
state->get_local_state(id())->cast<HashJoinProbeLocalState>();
+    local_state.init_for_probe(state);
     SCOPED_TIMER(local_state._probe_timer);
     if (local_state._shared_state->short_circuit_for_probe) {
         // If we use a short-circuit strategy, should return empty block 
directly.
@@ -331,7 +305,7 @@ bool 
HashJoinProbeOperatorX::need_more_input_data(RuntimeState* state) const {
 Status HashJoinProbeOperatorX::_do_evaluate(vectorized::Block& block,
                                             vectorized::VExprContextSPtrs& 
exprs,
                                             RuntimeProfile::Counter& 
expr_call_timer,
-                                            std::vector<int>& res_col_ids) {
+                                            std::vector<int>& res_col_ids) 
const {
     for (size_t i = 0; i < exprs.size(); ++i) {
         int result_col_id = -1;
         // execute build column
@@ -349,8 +323,9 @@ Status 
HashJoinProbeOperatorX::_do_evaluate(vectorized::Block& block,
 }
 
 Status HashJoinProbeOperatorX::push(RuntimeState* state, vectorized::Block* 
input_block,
-                                    SourceState source_state) {
+                                    SourceState source_state) const {
     auto& local_state = 
state->get_local_state(id())->cast<HashJoinProbeLocalState>();
+    local_state.prepare_for_next();
     local_state._probe_eos = source_state == SourceState::FINISHED;
     if (input_block->rows() > 0) {
         COUNTER_UPDATE(local_state._probe_rows_counter, input_block->rows());
diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.h 
b/be/src/pipeline/exec/hashjoin_probe_operator.h
index dc1402dce6..9be281d6dc 100644
--- a/be/src/pipeline/exec/hashjoin_probe_operator.h
+++ b/be/src/pipeline/exec/hashjoin_probe_operator.h
@@ -111,18 +111,17 @@ public:
     Status open(RuntimeState* state) override;
     bool can_read(RuntimeState* state) override;
 
-    Status get_block(RuntimeState* state, vectorized::Block* block,
-                     SourceState& source_state) override;
-
-    Status push(RuntimeState* state, vectorized::Block* input_block, 
SourceState source_state);
+    Status push(RuntimeState* state, vectorized::Block* input_block,
+                SourceState source_state) const override;
     Status pull(doris::RuntimeState* state, vectorized::Block* output_block,
-                SourceState& source_state);
+                SourceState& source_state) const override;
 
-    bool need_more_input_data(RuntimeState* state) const;
+    bool need_more_input_data(RuntimeState* state) const override;
 
 private:
     Status _do_evaluate(vectorized::Block& block, 
vectorized::VExprContextSPtrs& exprs,
-                        RuntimeProfile::Counter& expr_call_timer, 
std::vector<int>& res_col_ids);
+                        RuntimeProfile::Counter& expr_call_timer,
+                        std::vector<int>& res_col_ids) const;
     friend class HashJoinProbeLocalState;
     friend struct vectorized::HashJoinProbeContext;
 
diff --git a/be/src/pipeline/exec/join_probe_operator.h 
b/be/src/pipeline/exec/join_probe_operator.h
index 00f4767923..5debf92377 100644
--- a/be/src/pipeline/exec/join_probe_operator.h
+++ b/be/src/pipeline/exec/join_probe_operator.h
@@ -36,6 +36,8 @@ public:
     virtual void add_tuple_is_null_column(vectorized::Block* block) = 0;
 
 protected:
+    template <typename LocalStateType>
+    friend class StatefulOperatorX;
     JoinProbeLocalState(RuntimeState* state, OperatorXBase* parent)
             : Base(state, parent),
               _child_block(vectorized::Block::create_unique()),
@@ -62,9 +64,9 @@ protected:
 };
 
 template <typename LocalStateType>
-class JoinProbeOperatorX : public OperatorX<LocalStateType> {
+class JoinProbeOperatorX : public StatefulOperatorX<LocalStateType> {
 public:
-    using Base = OperatorX<LocalStateType>;
+    using Base = StatefulOperatorX<LocalStateType>;
     JoinProbeOperatorX(ObjectPool* pool, const TPlanNode& tnode, const 
DescriptorTbl& descs);
     virtual Status init(const TPlanNode& tnode, RuntimeState* state) override;
 
diff --git a/be/src/pipeline/exec/nested_loop_join_probe_operator.cpp 
b/be/src/pipeline/exec/nested_loop_join_probe_operator.cpp
index b590bc3eb3..0d025f6669 100644
--- a/be/src/pipeline/exec/nested_loop_join_probe_operator.cpp
+++ b/be/src/pipeline/exec/nested_loop_join_probe_operator.cpp
@@ -509,34 +509,6 @@ Status 
NestedLoopJoinProbeOperatorX::push(doris::RuntimeState* state, vectorized
     return Status::OK();
 }
 
-Status NestedLoopJoinProbeOperatorX::get_block(RuntimeState* state, 
vectorized::Block* block,
-                                               SourceState& source_state) {
-    auto& local_state = 
state->get_local_state(id())->cast<NestedLoopJoinProbeLocalState>();
-    if (need_more_input_data(state)) {
-        local_state._child_block->clear_column_data();
-        RETURN_IF_ERROR(_child_x->get_next_after_projects(state, 
local_state._child_block.get(),
-                                                          
local_state._child_source_state));
-        source_state = local_state._child_source_state;
-        if (local_state._child_block->rows() == 0 &&
-            local_state._child_source_state != SourceState::FINISHED) {
-            return Status::OK();
-        }
-        RETURN_IF_ERROR(
-                push(state, local_state._child_block.get(), 
local_state._child_source_state));
-    }
-
-    if (!need_more_input_data(state)) {
-        RETURN_IF_ERROR(pull(state, block, source_state));
-        if (source_state != SourceState::FINISHED && 
!need_more_input_data(state)) {
-            source_state = SourceState::MORE_DATA;
-        } else if (source_state != SourceState::FINISHED &&
-                   source_state == SourceState::MORE_DATA) {
-            source_state = local_state._child_source_state;
-        }
-    }
-    return Status::OK();
-}
-
 Status NestedLoopJoinProbeOperatorX::pull(RuntimeState* state, 
vectorized::Block* block,
                                           SourceState& source_state) const {
     auto& local_state = 
state->get_local_state(id())->cast<NestedLoopJoinProbeLocalState>();
diff --git a/be/src/pipeline/exec/nested_loop_join_probe_operator.h 
b/be/src/pipeline/exec/nested_loop_join_probe_operator.h
index c43e30e726..17ce4b4de8 100644
--- a/be/src/pipeline/exec/nested_loop_join_probe_operator.h
+++ b/be/src/pipeline/exec/nested_loop_join_probe_operator.h
@@ -210,13 +210,10 @@ public:
     Status open(RuntimeState* state) override;
     bool can_read(RuntimeState* state) override;
 
-    Status get_block(RuntimeState* state, vectorized::Block* block,
-                     SourceState& source_state) override;
-
     Status push(RuntimeState* state, vectorized::Block* input_block,
-                SourceState source_state) const;
+                SourceState source_state) const override;
     Status pull(doris::RuntimeState* state, vectorized::Block* output_block,
-                SourceState& source_state) const;
+                SourceState& source_state) const override;
     const RowDescriptor& intermediate_row_desc() const override {
         return _old_version_flag ? _row_descriptor : *_intermediate_row_desc;
     }
@@ -227,7 +224,7 @@ public:
                        : *_output_row_desc;
     }
 
-    bool need_more_input_data(RuntimeState* state) const;
+    bool need_more_input_data(RuntimeState* state) const override;
 
 private:
     friend class NestedLoopJoinProbeLocalState;
diff --git a/be/src/pipeline/exec/repeat_operator.cpp 
b/be/src/pipeline/exec/repeat_operator.cpp
index 3b47201292..e382477d78 100644
--- a/be/src/pipeline/exec/repeat_operator.cpp
+++ b/be/src/pipeline/exec/repeat_operator.cpp
@@ -104,45 +104,19 @@ bool RepeatOperatorX::need_more_input_data(RuntimeState* 
state) const {
     auto& local_state = state->get_local_state(id())->cast<RepeatLocalState>();
     return !local_state._child_block->rows() && !local_state._child_eos;
 }
-Status RepeatOperatorX::get_block(RuntimeState* state, vectorized::Block* 
block,
-                                  SourceState& source_state) {
-    auto& local_state = state->get_local_state(id())->cast<RepeatLocalState>();
-    if (need_more_input_data(state)) {
-        local_state._child_block->clear_column_data();
-        RETURN_IF_ERROR(_child_x->get_next_after_projects(state, 
local_state._child_block.get(),
-                                                          
local_state._child_source_state));
-        source_state = local_state._child_source_state;
-        if (local_state._child_block->rows() == 0 &&
-            local_state._child_source_state != SourceState::FINISHED) {
-            return Status::OK();
-        }
-        RETURN_IF_ERROR(
-                push(state, local_state._child_block.get(), 
local_state._child_source_state));
-    }
-
-    if (!need_more_input_data(state)) {
-        RETURN_IF_ERROR(pull(state, block, source_state));
-        if (source_state != SourceState::FINISHED && 
!need_more_input_data(state)) {
-            source_state = SourceState::MORE_DATA;
-        } else if (source_state != SourceState::FINISHED &&
-                   source_state == SourceState::MORE_DATA) {
-            source_state = local_state._child_source_state;
-        }
-    }
-    return Status::OK();
-}
 
-Status RepeatOperatorX::get_repeated_block(vectorized::Block* child_block, int 
repeat_id_idx,
-                                           vectorized::Block* output_block) {
+Status RepeatLocalState::get_repeated_block(vectorized::Block* child_block, 
int repeat_id_idx,
+                                            vectorized::Block* output_block) {
+    auto& p = _parent->cast<RepeatOperatorX>();
     DCHECK(child_block != nullptr);
     DCHECK_EQ(output_block->rows(), 0);
 
     size_t child_column_size = child_block->columns();
-    size_t column_size = _output_slots.size();
+    size_t column_size = p._output_slots.size();
     DCHECK_LT(child_column_size, column_size);
-    vectorized::MutableBlock m_block =
-            
vectorized::VectorizedUtils::build_mutable_mem_reuse_block(output_block, 
_output_slots);
-    vectorized::MutableColumns& columns = m_block.mutable_columns();
+    auto m_block = 
vectorized::VectorizedUtils::build_mutable_mem_reuse_block(output_block,
+                                                                              
p._output_slots);
+    auto& columns = m_block.mutable_columns();
     /* Fill all slots according to child, for example:select tc1,tc2,sum(tc3) 
from t1 group by grouping sets((tc1),(tc2));
      * insert into t1 values(1,2,1),(1,3,1),(2,1,1),(3,1,1);
      * slot_id_set_list=[[0],[1]],repeat_id_idx=0,
@@ -153,14 +127,14 @@ Status 
RepeatOperatorX::get_repeated_block(vectorized::Block* child_block, int r
     for (size_t i = 0; i < child_column_size; i++) {
         const vectorized::ColumnWithTypeAndName& src_column = 
child_block->get_by_position(i);
 
-        std::set<SlotId>& repeat_ids = _slot_id_set_list[repeat_id_idx];
+        std::set<SlotId>& repeat_ids = p._slot_id_set_list[repeat_id_idx];
         bool is_repeat_slot =
-                _all_slot_ids.find(_output_slots[cur_col]->id()) != 
_all_slot_ids.end();
-        bool is_set_null_slot = repeat_ids.find(_output_slots[cur_col]->id()) 
== repeat_ids.end();
+                p._all_slot_ids.find(p._output_slots[cur_col]->id()) != 
p._all_slot_ids.end();
+        bool is_set_null_slot = 
repeat_ids.find(p._output_slots[cur_col]->id()) == repeat_ids.end();
         const auto row_size = src_column.column->size();
 
         if (is_repeat_slot) {
-            DCHECK(_output_slots[cur_col]->is_nullable());
+            DCHECK(p._output_slots[cur_col]->is_nullable());
             auto* nullable_column =
                     
reinterpret_cast<vectorized::ColumnNullable*>(columns[cur_col].get());
             auto& null_map = nullable_column->get_null_map_data();
@@ -187,14 +161,14 @@ Status 
RepeatOperatorX::get_repeated_block(vectorized::Block* child_block, int r
     }
 
     // Fill grouping ID to block
-    for (auto slot_idx = 0; slot_idx < _grouping_list.size(); slot_idx++) {
-        DCHECK_LT(slot_idx, _output_tuple_desc->slots().size());
-        const SlotDescriptor* _virtual_slot_desc = 
_output_tuple_desc->slots()[cur_col];
-        DCHECK_EQ(_virtual_slot_desc->type().type, 
_output_slots[cur_col]->type().type);
-        DCHECK_EQ(_virtual_slot_desc->col_name(), 
_output_slots[cur_col]->col_name());
-        int64_t val = _grouping_list[slot_idx][repeat_id_idx];
+    for (auto slot_idx = 0; slot_idx < p._grouping_list.size(); slot_idx++) {
+        DCHECK_LT(slot_idx, p._output_tuple_desc->slots().size());
+        const SlotDescriptor* _virtual_slot_desc = 
p._output_tuple_desc->slots()[cur_col];
+        DCHECK_EQ(_virtual_slot_desc->type().type, 
p._output_slots[cur_col]->type().type);
+        DCHECK_EQ(_virtual_slot_desc->col_name(), 
p._output_slots[cur_col]->col_name());
+        int64_t val = p._grouping_list[slot_idx][repeat_id_idx];
         auto* column_ptr = columns[cur_col].get();
-        DCHECK(!_output_slots[cur_col]->is_nullable());
+        DCHECK(!p._output_slots[cur_col]->is_nullable());
 
         auto* col = 
assert_cast<vectorized::ColumnVector<vectorized::Int64>*>(column_ptr);
         for (size_t i = 0; i < child_block->rows(); ++i) {
@@ -209,7 +183,7 @@ Status 
RepeatOperatorX::get_repeated_block(vectorized::Block* child_block, int r
 }
 
 Status RepeatOperatorX::push(RuntimeState* state, vectorized::Block* 
input_block,
-                             SourceState& source_state) {
+                             SourceState source_state) const {
     auto& local_state = state->get_local_state(id())->cast<RepeatLocalState>();
     local_state._child_eos = source_state == SourceState::FINISHED;
     auto& _intermediate_block = local_state._intermediate_block;
@@ -234,8 +208,9 @@ Status RepeatOperatorX::push(RuntimeState* state, 
vectorized::Block* input_block
 
     return Status::OK();
 }
+
 Status RepeatOperatorX::pull(doris::RuntimeState* state, vectorized::Block* 
output_block,
-                             SourceState& source_state) {
+                             SourceState& source_state) const {
     auto& local_state = state->get_local_state(id())->cast<RepeatLocalState>();
     auto& _repeat_id_idx = local_state._repeat_id_idx;
     auto& _child_block = *local_state._child_block;
@@ -249,15 +224,15 @@ Status RepeatOperatorX::pull(doris::RuntimeState* state, 
vectorized::Block* outp
     DCHECK(output_block->rows() == 0);
 
     if (_intermediate_block && _intermediate_block->rows() > 0) {
-        RETURN_IF_ERROR(
-                get_repeated_block(_intermediate_block.get(), _repeat_id_idx, 
output_block));
+        
RETURN_IF_ERROR(local_state.get_repeated_block(_intermediate_block.get(), 
_repeat_id_idx,
+                                                       output_block));
 
         _repeat_id_idx++;
 
         int size = _repeat_id_list.size();
         if (_repeat_id_idx >= size) {
             _intermediate_block->clear();
-            release_block_memory(_child_block);
+            
_child_block.clear_column_data(_child_x->row_desc().num_materialized_slots());
             _repeat_id_idx = 0;
         }
     }
@@ -270,4 +245,5 @@ Status RepeatOperatorX::pull(doris::RuntimeState* state, 
vectorized::Block* outp
     COUNTER_SET(local_state._rows_returned_counter, 
local_state._num_rows_returned);
     return Status::OK();
 }
+
 } // namespace doris::pipeline
diff --git a/be/src/pipeline/exec/repeat_operator.h 
b/be/src/pipeline/exec/repeat_operator.h
index f81fd19ab2..976b704ccd 100644
--- a/be/src/pipeline/exec/repeat_operator.h
+++ b/be/src/pipeline/exec/repeat_operator.h
@@ -55,8 +55,13 @@ public:
 
     Status init(RuntimeState* state, LocalStateInfo& info) override;
 
+    Status get_repeated_block(vectorized::Block* child_block, int 
repeat_id_idx,
+                              vectorized::Block* output_block);
+
 private:
     friend class RepeatOperatorX;
+    template <typename LocalStateType>
+    friend class StatefulOperatorX;
     std::unique_ptr<vectorized::Block> _child_block;
     SourceState _child_source_state;
     bool _child_eos;
@@ -64,25 +69,24 @@ private:
     std::unique_ptr<vectorized::Block> _intermediate_block {};
     vectorized::VExprContextSPtrs _expr_ctxs;
 };
-class RepeatOperatorX final : public OperatorX<RepeatLocalState> {
+class RepeatOperatorX final : public StatefulOperatorX<RepeatLocalState> {
 public:
-    using Base = OperatorX<RepeatLocalState>;
+    using Base = StatefulOperatorX<RepeatLocalState>;
     RepeatOperatorX(ObjectPool* pool, const TPlanNode& tnode, const 
DescriptorTbl& descs);
-    Status get_block(RuntimeState* state, vectorized::Block* block,
-                     SourceState& source_state) override;
     Status init(const TPlanNode& tnode, RuntimeState* state) override;
 
     Status prepare(RuntimeState* state) override;
     Status open(RuntimeState* state) override;
 
+    bool need_more_input_data(RuntimeState* state) const override;
+    Status pull(RuntimeState* state, vectorized::Block* output_block,
+                SourceState& source_state) const override;
+    Status push(RuntimeState* state, vectorized::Block* input_block,
+                SourceState source_state) const override;
+
 private:
     friend class RepeatLocalState;
-    Status get_repeated_block(vectorized::Block* child_block, int 
repeat_id_idx,
-                              vectorized::Block* output_block);
-    bool need_more_input_data(RuntimeState* state) const;
 
-    Status pull(RuntimeState* state, vectorized::Block* output_block, 
SourceState& source_state);
-    Status push(RuntimeState* state, vectorized::Block* input_block, 
SourceState& source_state);
     // Slot id set used to indicate those slots need to set to null.
     std::vector<std::set<SlotId>> _slot_id_set_list;
     // all slot id
diff --git a/be/src/pipeline/pipeline_x/operator.cpp 
b/be/src/pipeline/pipeline_x/operator.cpp
index 8b720f1a0d..067d975496 100644
--- a/be/src/pipeline/pipeline_x/operator.cpp
+++ b/be/src/pipeline/pipeline_x/operator.cpp
@@ -21,6 +21,7 @@
 #include "pipeline/exec/aggregation_source_operator.h"
 #include "pipeline/exec/analytic_sink_operator.h"
 #include "pipeline/exec/analytic_source_operator.h"
+#include "pipeline/exec/assert_num_rows_operator.h"
 #include "pipeline/exec/exchange_sink_operator.h"
 #include "pipeline/exec/exchange_source_operator.h"
 #include "pipeline/exec/hashjoin_build_sink.h"
@@ -158,10 +159,6 @@ Status 
OperatorXBase::get_next_after_projects(RuntimeState* state, vectorized::B
     return get_block(state, block, source_state);
 }
 
-void OperatorXBase::release_block_memory(vectorized::Block& block) {
-    block.clear_column_data(_child_x->row_desc().num_materialized_slots());
-}
-
 bool PipelineXLocalStateBase::reached_limit() const {
     return _parent->_limit != -1 && _num_rows_returned >= _parent->_limit;
 }
@@ -217,6 +214,44 @@ Status 
OperatorX<LocalStateType>::setup_local_state(RuntimeState* state, LocalSt
     return local_state->init(state, info);
 }
 
+template <typename LocalStateType>
+Status StreamingOperatorX<LocalStateType>::get_block(RuntimeState* state, 
vectorized::Block* block,
+                                                     SourceState& 
source_state) {
+    
RETURN_IF_ERROR(OperatorX<LocalStateType>::_child_x->get_next_after_projects(state,
 block,
+                                                                               
  source_state));
+    return pull(state, block, source_state);
+}
+// Drives push()/pull(): fetch+push a child block while more input is needed, then pull output.
+template <typename LocalStateType>
+Status StatefulOperatorX<LocalStateType>::get_block(RuntimeState* state, vectorized::Block* block,
+                                                    SourceState& source_state) {
+    auto& local_state = state->get_local_state(OperatorX<LocalStateType>::id())
+                                ->template cast<LocalStateType>();
+    if (need_more_input_data(state)) { // phase 1: fetch one child block and push() it
+        local_state._child_block->clear_column_data();
+        RETURN_IF_ERROR(OperatorX<LocalStateType>::_child_x->get_next_after_projects(
+                state, local_state._child_block.get(), local_state._child_source_state));
+        source_state = local_state._child_source_state;
+        if (local_state._child_block->rows() == 0 &&
+            local_state._child_source_state != SourceState::FINISHED) {
+            return Status::OK(); // empty, unfinished child block: nothing to push yet
+        }
+        RETURN_IF_ERROR(
+                push(state, local_state._child_block.get(), local_state._child_source_state));
+    }
+    // Phase 2: once no more input is needed, pull() output into `block`.
+    if (!need_more_input_data(state)) {
+        RETURN_IF_ERROR(pull(state, block, source_state));
+        if (source_state != SourceState::FINISHED && !need_more_input_data(state)) {
+            source_state = SourceState::MORE_DATA; // still no input needed: output pending
+        } else if (source_state != SourceState::FINISHED &&
+                   source_state == SourceState::MORE_DATA) {
+            source_state = local_state._child_source_state; // report the child's state
+        }
+    }
+    return Status::OK();
+}
+
 #define DECLARE_OPERATOR_X(LOCAL_STATE) template class DataSinkOperatorX<LOCAL_STATE>;
 DECLARE_OPERATOR_X(HashJoinBuildSinkLocalState)
 DECLARE_OPERATOR_X(ResultSinkLocalState)
@@ -238,7 +273,14 @@ DECLARE_OPERATOR_X(AggLocalState)
 DECLARE_OPERATOR_X(ExchangeLocalState)
 DECLARE_OPERATOR_X(RepeatLocalState)
 DECLARE_OPERATOR_X(NestedLoopJoinProbeLocalState)
+DECLARE_OPERATOR_X(AssertNumRowsLocalState)
 
 #undef DECLARE_OPERATOR_X
 
+template class StreamingOperatorX<AssertNumRowsLocalState>;
+
+template class StatefulOperatorX<HashJoinProbeLocalState>;
+template class StatefulOperatorX<RepeatLocalState>;
+template class StatefulOperatorX<NestedLoopJoinProbeLocalState>;
+
 } // namespace doris::pipeline
diff --git a/be/src/pipeline/pipeline_x/operator.h b/be/src/pipeline/pipeline_x/operator.h
index 4b825bb2b2..d0dddda09b 100644
--- a/be/src/pipeline/pipeline_x/operator.h
+++ b/be/src/pipeline/pipeline_x/operator.h
@@ -83,7 +83,7 @@ public:
     RuntimeState* state() { return _state; }
     vectorized::VExprContextSPtrs& conjuncts() { return _conjuncts; }
     vectorized::VExprContextSPtrs& projections() { return _projections; }
-    int64_t num_rows_returned() const { return _num_rows_returned; }
+    [[nodiscard]] int64_t num_rows_returned() const { return _num_rows_returned; }
     void add_num_rows_returned(int64_t delta) { _num_rows_returned += delta; }
     void set_num_rows_returned(int64_t value) { _num_rows_returned = value; }
 
@@ -228,11 +228,6 @@ public:
                           vectorized::Block* output_block) const;
 
 protected:
-    /// Release all memory of block which got from child. The block
-    // 1. clear mem of valid column get from child, make sure child can reuse the mem
-    // 2. delete and release the column which create by function all and other reason
-    void release_block_memory(vectorized::Block& block);
-
     template <typename Dependency>
     friend class PipelineXLocalState;
     friend class PipelineXLocalStateBase;
@@ -500,4 +495,46 @@ protected:
     typename DependencyType::SharedState* _shared_state;
 };
 
+/**
+ * StreamingOperatorX indicates operators which always processes block in streaming way (one-in-one-out).
+ */
+template <typename LocalStateType>
+class StreamingOperatorX : public OperatorX<LocalStateType> {
+public:
+    StreamingOperatorX(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs)
+            : OperatorX<LocalStateType>(pool, tnode, descs) {}
+    virtual ~StreamingOperatorX() = default;
+
+    Status get_block(RuntimeState* state, vectorized::Block* block,
+                     SourceState& source_state) override;
+
+    virtual Status pull(RuntimeState* state, vectorized::Block* block,
+                        SourceState& source_state) = 0;
+};
+
+/**
+ * StatefulOperatorX indicates the operators with some states inside.
+ *
+ * Specifically, we called an operator stateful if an operator can determine its output by itself.
+ * For example, hash join probe operator is a typical StatefulOperator. When it gets a block from probe side, it will hold this block inside (e.g. _child_block).
+ * If there are still remain rows in probe block, we can get output block by calling `get_block` without any data from its child.
+ * In a nutshell, it is a one-to-many relation between input blocks and output blocks for StatefulOperator.
+ */
+template <typename LocalStateType>
+class StatefulOperatorX : public OperatorX<LocalStateType> {
+public:
+    StatefulOperatorX(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs)
+            : OperatorX<LocalStateType>(pool, tnode, descs) {}
+    virtual ~StatefulOperatorX() = default;
+
+    [[nodiscard]] Status get_block(RuntimeState* state, vectorized::Block* block,
+                                   SourceState& source_state) override;
+
+    [[nodiscard]] virtual Status pull(RuntimeState* state, vectorized::Block* block,
+                                      SourceState& source_state) const = 0;
+    [[nodiscard]] virtual Status push(RuntimeState* state, vectorized::Block* input_block,
+                                      SourceState source_state) const = 0;
+    [[nodiscard]] virtual bool need_more_input_data(RuntimeState* state) const = 0;
+};
+
 } // namespace doris::pipeline
diff --git a/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp b/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
index 3ad77fde35..916b88c7f7 100644
--- a/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
@@ -45,6 +45,7 @@
 #include "pipeline/exec/aggregation_source_operator.h"
 #include "pipeline/exec/analytic_sink_operator.h"
 #include "pipeline/exec/analytic_source_operator.h"
+#include "pipeline/exec/assert_num_rows_operator.h"
 #include "pipeline/exec/data_queue.h"
 #include "pipeline/exec/datagen_operator.h"
 #include "pipeline/exec/exchange_sink_operator.h"
@@ -628,6 +629,11 @@ Status PipelineXFragmentContext::_create_operator(ObjectPool* pool, const TPlanN
         RETURN_IF_ERROR(cur_pipe->add_operator(op));
         break;
     }
+    case TPlanNodeType::ASSERT_NUM_ROWS_NODE: {
+        op.reset(new AssertNumRowsOperatorX(pool, tnode, descs));
+        RETURN_IF_ERROR(cur_pipe->add_operator(op));
+        break;
+    }
     default:
         return Status::InternalError("Unsupported exec type in pipelineX: {}",
                                      print_plan_node_type(tnode.node_type));


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to