This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new fb5454c213c [improvement](spill) avoid unnecessary spilling in hash join build phase (#33277)
fb5454c213c is described below

commit fb5454c213c338bd7d6259803c5536d51f9d164a
Author: Jerry Hu <mrh...@gmail.com>
AuthorDate: Mon Apr 8 07:31:13 2024 +0800

    [improvement](spill) avoid unnecessary spilling in hash join build phase (#33277)
---
 be/src/pipeline/exec/hashjoin_build_sink.h         |   1 +
 .../exec/partitioned_hash_join_probe_operator.cpp  |  90 +++++----
 .../exec/partitioned_hash_join_probe_operator.h    |  12 +-
 .../exec/partitioned_hash_join_sink_operator.cpp   | 217 ++++++++++++++++-----
 .../exec/partitioned_hash_join_sink_operator.h     |  17 ++
 be/src/pipeline/pipeline_x/dependency.h            |   2 +
 .../pipeline_x/pipeline_x_fragment_context.cpp     |  28 ++-
 be/src/vec/spill/spill_writer.cpp                  |   7 +-
 8 files changed, 283 insertions(+), 91 deletions(-)

diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h b/be/src/pipeline/exec/hashjoin_build_sink.h
index 0849027e32a..2712edc838d 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.h
+++ b/be/src/pipeline/exec/hashjoin_build_sink.h
@@ -81,6 +81,7 @@ protected:
                                 vectorized::ColumnRawPtrs& raw_ptrs,
                                 const std::vector<int>& res_col_ids);
     friend class HashJoinBuildSinkOperatorX;
+    friend class PartitionedHashJoinSinkLocalState;
     template <class HashTableContext, typename Parent>
     friend struct vectorized::ProcessHashTableBuild;
     template <typename Parent>
diff --git a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
index 0c74834ab61..a8eb926243f 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
@@ -487,22 +487,16 @@ Status PartitionedHashJoinProbeOperatorX::init(const TPlanNode& tnode, RuntimeSt
         _probe_exprs.emplace_back(conjunct.left);
     }
 
-    _sink_operator =
-            std::make_unique<HashJoinBuildSinkOperatorX>(_pool, 0, tnode_, 
_descriptor_tbl, false);
-    _probe_operator = std::make_unique<HashJoinProbeOperatorX>(_pool, tnode_, 
0, _descriptor_tbl);
-    RETURN_IF_ERROR(_sink_operator->init(tnode_, state));
-    return _probe_operator->init(tnode_, state);
+    return Status::OK();
 }
 Status PartitionedHashJoinProbeOperatorX::prepare(RuntimeState* state) {
     // to avoid prepare _child_x twice
     auto child_x = std::move(_child_x);
     RETURN_IF_ERROR(JoinProbeOperatorX::prepare(state));
-    RETURN_IF_ERROR(_probe_operator->set_child(child_x));
+    RETURN_IF_ERROR(_inner_probe_operator->set_child(child_x));
     DCHECK(_build_side_child != nullptr);
-    _probe_operator->set_build_side_child(_build_side_child);
-    RETURN_IF_ERROR(_sink_operator->set_child(_build_side_child));
-    RETURN_IF_ERROR(_probe_operator->prepare(state));
-    RETURN_IF_ERROR(_sink_operator->prepare(state));
+    _inner_probe_operator->set_build_side_child(_build_side_child);
+    RETURN_IF_ERROR(_inner_probe_operator->prepare(state));
     _child_x = std::move(child_x);
     return Status::OK();
 }
@@ -511,8 +505,7 @@ Status PartitionedHashJoinProbeOperatorX::open(RuntimeState* state) {
     // to avoid open _child_x twice
     auto child_x = std::move(_child_x);
     RETURN_IF_ERROR(JoinProbeOperatorX::open(state));
-    RETURN_IF_ERROR(_probe_operator->open(state));
-    RETURN_IF_ERROR(_sink_operator->open(state));
+    RETURN_IF_ERROR(_inner_probe_operator->open(state));
     _child_x = std::move(child_x);
     return Status::OK();
 }
@@ -570,6 +563,15 @@ Status PartitionedHashJoinProbeOperatorX::push(RuntimeState* state, vectorized::
     return Status::OK();
 }
 
+Status 
PartitionedHashJoinProbeOperatorX::_setup_internal_operator_for_non_spill(
+        PartitionedHashJoinProbeLocalState& local_state, RuntimeState* state) {
+    DCHECK(local_state._shared_state->inner_runtime_state);
+    local_state._runtime_state = 
std::move(local_state._shared_state->inner_runtime_state);
+    local_state._in_mem_shared_state_sptr =
+            std::move(local_state._shared_state->inner_shared_state);
+    return Status::OK();
+}
+
 Status PartitionedHashJoinProbeOperatorX::_setup_internal_operators(
         PartitionedHashJoinProbeLocalState& local_state, RuntimeState* state) 
const {
     if (local_state._runtime_state) {
@@ -589,13 +591,14 @@ Status PartitionedHashJoinProbeOperatorX::_setup_internal_operators(
     local_state._runtime_state->set_pipeline_x_runtime_filter_mgr(
             state->local_runtime_filter_mgr());
 
-    local_state._in_mem_shared_state_sptr = 
_sink_operator->create_shared_state();
+    local_state._in_mem_shared_state_sptr = 
_inner_sink_operator->create_shared_state();
 
     // set sink local state
     LocalSinkStateInfo info {0,  local_state._internal_runtime_profile.get(),
                              -1, local_state._in_mem_shared_state_sptr.get(),
                              {}, {}};
-    
RETURN_IF_ERROR(_sink_operator->setup_local_state(local_state._runtime_state.get(),
 info));
+    RETURN_IF_ERROR(
+            
_inner_sink_operator->setup_local_state(local_state._runtime_state.get(), 
info));
 
     LocalStateInfo state_info {local_state._internal_runtime_profile.get(),
                                {},
@@ -603,14 +606,14 @@ Status PartitionedHashJoinProbeOperatorX::_setup_internal_operators(
                                {},
                                0};
     RETURN_IF_ERROR(
-            
_probe_operator->setup_local_state(local_state._runtime_state.get(), 
state_info));
+            
_inner_probe_operator->setup_local_state(local_state._runtime_state.get(), 
state_info));
 
     auto* sink_local_state = 
local_state._runtime_state->get_sink_local_state();
     DCHECK(sink_local_state != nullptr);
     RETURN_IF_ERROR(sink_local_state->open(state));
 
     auto* probe_local_state =
-            
local_state._runtime_state->get_local_state(_probe_operator->operator_id());
+            
local_state._runtime_state->get_local_state(_inner_probe_operator->operator_id());
     DCHECK(probe_local_state != nullptr);
     RETURN_IF_ERROR(probe_local_state->open(state));
 
@@ -621,7 +624,7 @@ Status PartitionedHashJoinProbeOperatorX::_setup_internal_operators(
         block = partitioned_block->to_block();
         partitioned_block.reset();
     }
-    RETURN_IF_ERROR(_sink_operator->sink(local_state._runtime_state.get(), 
&block, true));
+    
RETURN_IF_ERROR(_inner_sink_operator->sink(local_state._runtime_state.get(), 
&block, true));
     LOG(INFO) << "internal build operator finished, node id: " << id()
               << ", task id: " << state->task_id()
               << ", partition: " << local_state._partition_cursor;
@@ -662,7 +665,7 @@ Status PartitionedHashJoinProbeOperatorX::pull(doris::RuntimeState* state,
     bool in_mem_eos_;
     auto* runtime_state = local_state._runtime_state.get();
     auto& probe_blocks = local_state._probe_blocks[partition_index];
-    while (_probe_operator->need_more_input_data(runtime_state)) {
+    while (_inner_probe_operator->need_more_input_data(runtime_state)) {
         if (probe_blocks.empty()) {
             *eos = false;
             bool has_data = false;
@@ -670,7 +673,7 @@ Status PartitionedHashJoinProbeOperatorX::pull(doris::RuntimeState* state,
                     local_state.recovery_probe_blocks_from_disk(state, 
partition_index, has_data));
             if (!has_data) {
                 vectorized::Block block;
-                RETURN_IF_ERROR(_probe_operator->push(runtime_state, &block, 
true));
+                RETURN_IF_ERROR(_inner_probe_operator->push(runtime_state, 
&block, true));
                 break;
             } else {
                 return Status::OK();
@@ -679,11 +682,11 @@ Status PartitionedHashJoinProbeOperatorX::pull(doris::RuntimeState* state,
 
         auto block = std::move(probe_blocks.back());
         probe_blocks.pop_back();
-        RETURN_IF_ERROR(_probe_operator->push(runtime_state, &block, false));
+        RETURN_IF_ERROR(_inner_probe_operator->push(runtime_state, &block, 
false));
     }
 
-    RETURN_IF_ERROR(
-            _probe_operator->pull(local_state._runtime_state.get(), 
output_block, &in_mem_eos_));
+    
RETURN_IF_ERROR(_inner_probe_operator->pull(local_state._runtime_state.get(), 
output_block,
+                                                &in_mem_eos_));
 
     *eos = false;
     if (in_mem_eos_) {
@@ -701,7 +704,13 @@ Status PartitionedHashJoinProbeOperatorX::pull(doris::RuntimeState* state,
 
 bool PartitionedHashJoinProbeOperatorX::need_more_input_data(RuntimeState* 
state) const {
     auto& local_state = get_local_state(state);
-    return !local_state._child_eos;
+    if (local_state._shared_state->need_to_spill) {
+        return !local_state._child_eos;
+    } else if (local_state._runtime_state) {
+        return 
_inner_probe_operator->need_more_input_data(local_state._runtime_state.get());
+    } else {
+        return true;
+    }
 }
 
 bool PartitionedHashJoinProbeOperatorX::need_data_from_children(RuntimeState* 
state) const {
@@ -793,7 +802,7 @@ void PartitionedHashJoinProbeOperatorX::_update_profile_from_internal_states(
         auto* sink_local_state = 
local_state._runtime_state->get_sink_local_state();
         local_state.update_build_profile(sink_local_state->profile());
         auto* probe_local_state =
-                
local_state._runtime_state->get_local_state(_probe_operator->operator_id());
+                
local_state._runtime_state->get_local_state(_inner_probe_operator->operator_id());
         local_state.update_probe_profile(probe_local_state->profile());
     }
 }
@@ -803,14 +812,12 @@ Status PartitionedHashJoinProbeOperatorX::get_block(RuntimeState* state, vectori
     *eos = false;
     auto& local_state = get_local_state(state);
     SCOPED_TIMER(local_state.exec_time_counter());
+    const auto need_to_spill = local_state._shared_state->need_to_spill;
     if (need_more_input_data(state)) {
-        local_state._child_block->clear_column_data();
-
-        if (_should_revoke_memory(state)) {
+        if (need_to_spill && _should_revoke_memory(state)) {
             bool wait_for_io = false;
             RETURN_IF_ERROR(_revoke_memory(state, wait_for_io));
             if (wait_for_io) {
-                local_state._shared_state->need_to_spill = true;
                 return Status::OK();
             }
         }
@@ -818,20 +825,37 @@ Status PartitionedHashJoinProbeOperatorX::get_block(RuntimeState* state, vectori
         RETURN_IF_ERROR(_child_x->get_block_after_projects(state, 
local_state._child_block.get(),
                                                            
&local_state._child_eos));
 
-        if (local_state._child_eos) {
+        if (need_to_spill && local_state._child_eos) {
             RETURN_IF_ERROR(local_state.finish_spilling(0));
-        } else if (local_state._child_block->rows() == 0) {
-            return Status::OK();
         }
-        {
+
+        Defer defer([&] { local_state._child_block->clear_column_data(); });
+        if (need_to_spill) {
             SCOPED_TIMER(local_state.exec_time_counter());
             RETURN_IF_ERROR(push(state, local_state._child_block.get(), 
local_state._child_eos));
+        } else {
+            if (UNLIKELY(!local_state._runtime_state)) {
+                
RETURN_IF_ERROR(_setup_internal_operator_for_non_spill(local_state, state));
+            }
+
+            
RETURN_IF_ERROR(_inner_probe_operator->push(local_state._runtime_state.get(),
+                                                        
local_state._child_block.get(),
+                                                        
local_state._child_eos));
         }
     }
 
     if (!need_more_input_data(state)) {
         SCOPED_TIMER(local_state.exec_time_counter());
-        RETURN_IF_ERROR(pull(state, block, eos));
+        if (need_to_spill) {
+            RETURN_IF_ERROR(pull(state, block, eos));
+        } else {
+            RETURN_IF_ERROR(
+                    
_inner_probe_operator->pull(local_state._runtime_state.get(), block, eos));
+            if (*eos) {
+                local_state._runtime_state.reset();
+            }
+        }
+
         local_state.add_num_rows_returned(block->rows());
         if (*eos) {
             _update_profile_from_internal_states(local_state);
diff --git a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.h b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.h
index 143576e1b86..96a5cf96e34 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.h
+++ b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.h
@@ -182,6 +182,12 @@ public:
 
     bool need_data_from_children(RuntimeState* state) const override;
 
+    void set_inner_operators(const 
std::shared_ptr<HashJoinBuildSinkOperatorX>& sink_operator,
+                             const std::shared_ptr<HashJoinProbeOperatorX>& 
probe_operator) {
+        _inner_sink_operator = sink_operator;
+        _inner_probe_operator = probe_operator;
+    }
+
 private:
     Status _revoke_memory(RuntimeState* state, bool& wait_for_io);
 
@@ -189,6 +195,8 @@ private:
 
     [[nodiscard]] Status 
_setup_internal_operators(PartitionedHashJoinProbeLocalState& local_state,
                                                    RuntimeState* state) const;
+    [[nodiscard]] Status _setup_internal_operator_for_non_spill(
+            PartitionedHashJoinProbeLocalState& local_state, RuntimeState* 
state);
 
     bool _should_revoke_memory(RuntimeState* state) const;
 
@@ -197,8 +205,8 @@ private:
 
     const TJoinDistributionType::type _join_distribution;
 
-    std::unique_ptr<HashJoinBuildSinkOperatorX> _sink_operator;
-    std::unique_ptr<HashJoinProbeOperatorX> _probe_operator;
+    std::shared_ptr<HashJoinBuildSinkOperatorX> _inner_sink_operator;
+    std::shared_ptr<HashJoinProbeOperatorX> _inner_probe_operator;
 
     // probe expr
     std::vector<TExpr> _probe_exprs;
diff --git a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
index 8b9accd30ad..4fd399464c1 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
@@ -30,6 +30,8 @@ Status PartitionedHashJoinSinkLocalState::init(doris::RuntimeState* state,
     _shared_state->partitioned_build_blocks.resize(p._partition_count);
     _shared_state->spilled_streams.resize(p._partition_count);
 
+    _internal_runtime_profile.reset(new RuntimeProfile("internal_profile"));
+
     _partitioner = std::make_unique<PartitionerType>(p._partition_count);
     RETURN_IF_ERROR(_partitioner->init(p._build_exprs));
 
@@ -55,10 +57,69 @@ Status PartitionedHashJoinSinkLocalState::close(RuntimeState* state, Status exec
     return PipelineXSpillSinkLocalState::close(state, exec_status);
 }
 
+size_t PartitionedHashJoinSinkLocalState::revocable_mem_size(RuntimeState* 
state) const {
+    /// If no need to spill, all rows were sunk into the `_inner_sink_operator` without partitioned.
+    if (!_shared_state->need_to_spill) {
+        if (_shared_state->inner_shared_state && 
_shared_state->inner_shared_state->build_block) {
+            return 
_shared_state->inner_shared_state->build_block->allocated_bytes();
+        } else if (_shared_state->inner_runtime_state) {
+            auto inner_sink_state = 
_shared_state->inner_runtime_state->get_sink_local_state();
+
+            if (inner_sink_state) {
+                auto& build_block = 
reinterpret_cast<HashJoinBuildSinkLocalState*>(inner_sink_state)
+                                            ->_build_side_mutable_block;
+                return build_block.allocated_bytes();
+            }
+        }
+        return 0;
+    }
+
+    size_t mem_size = 0;
+    auto& partitioned_blocks = _shared_state->partitioned_build_blocks;
+    for (auto& block : partitioned_blocks) {
+        if (block) {
+            auto block_bytes = block->allocated_bytes();
+            if (block_bytes >= 
vectorized::SpillStream::MIN_SPILL_WRITE_BATCH_MEM) {
+                mem_size += block_bytes;
+            }
+        }
+    }
+    return mem_size;
+}
+
 Status PartitionedHashJoinSinkLocalState::revoke_memory(RuntimeState* state) {
     LOG(INFO) << "hash join sink " << _parent->id() << " revoke_memory"
               << ", eos: " << _child_eos;
     DCHECK_EQ(_spilling_streams_count, 0);
+
+    if (!_shared_state->need_to_spill) {
+        auto& p = _parent->cast<PartitionedHashJoinSinkOperatorX>();
+        _shared_state->inner_shared_state->hash_table_variants.reset();
+        auto row_desc = p._child_x->row_desc();
+        auto build_block = 
std::move(_shared_state->inner_shared_state->build_block);
+        if (!build_block) {
+            build_block = vectorized::Block::create_shared();
+            auto inner_sink_state = 
_shared_state->inner_runtime_state->get_sink_local_state();
+            if (inner_sink_state) {
+                auto& mutable_block =
+                        
reinterpret_cast<HashJoinBuildSinkLocalState*>(inner_sink_state)
+                                ->_build_side_mutable_block;
+                *build_block = mutable_block.to_block();
+                LOG(INFO) << "hash join sink will revoke build mutable block: "
+                          << build_block->allocated_bytes();
+            }
+        }
+
+        /// Here need to skip the first row in build block.
+        /// The first row in build block is generated by `HashJoinBuildSinkOperatorX::sink`.
+        if (build_block->rows() > 1) {
+            if (build_block->columns() > row_desc.num_slots()) {
+                build_block->erase(row_desc.num_slots());
+            }
+            RETURN_IF_ERROR(_partition_block(state, build_block.get(), 1, 
build_block->rows()));
+        }
+    }
+
     _spilling_streams_count = _shared_state->partitioned_build_blocks.size();
     for (size_t i = 0; i != _shared_state->partitioned_build_blocks.size(); 
++i) {
         vectorized::SpillStreamSPtr& spilling_stream = 
_shared_state->spilled_streams[i];
@@ -124,6 +185,45 @@ Status PartitionedHashJoinSinkLocalState::revoke_memory(RuntimeState* state) {
     return Status::OK();
 }
 
+Status PartitionedHashJoinSinkLocalState::_partition_block(RuntimeState* state,
+                                                           vectorized::Block* 
in_block,
+                                                           size_t begin, 
size_t end) {
+    const auto rows = in_block->rows();
+    if (!rows) {
+        return Status::OK();
+    }
+    {
+        /// TODO: DO NOT execute build exprs twice(when partition and building hash table)
+        SCOPED_TIMER(_partition_timer);
+        RETURN_IF_ERROR(_partitioner->do_partitioning(state, in_block, 
_mem_tracker.get()));
+    }
+
+    auto& p = _parent->cast<PartitionedHashJoinSinkOperatorX>();
+    SCOPED_TIMER(_partition_shuffle_timer);
+    auto* channel_ids = 
reinterpret_cast<uint64_t*>(_partitioner->get_channel_ids());
+    std::vector<uint32_t> partition_indexes[p._partition_count];
+    for (uint32_t i = 0; i != rows; ++i) {
+        partition_indexes[channel_ids[i]].emplace_back(i);
+    }
+
+    auto& partitioned_blocks = _shared_state->partitioned_build_blocks;
+    for (uint32_t i = 0; i != p._partition_count; ++i) {
+        const auto count = partition_indexes[i].size();
+        if (UNLIKELY(count == 0)) {
+            continue;
+        }
+
+        if (UNLIKELY(!partitioned_blocks[i])) {
+            partitioned_blocks[i] =
+                    
vectorized::MutableBlock::create_unique(in_block->clone_empty());
+        }
+        partitioned_blocks[i]->add_rows(in_block, &(partition_indexes[i][0]),
+                                        &(partition_indexes[i][count]));
+    }
+
+    return Status::OK();
+}
+
 void PartitionedHashJoinSinkLocalState::_spill_to_disk(
         uint32_t partition_index, const vectorized::SpillStreamSPtr& 
spilling_stream) {
     auto& partitioned_block = 
_shared_state->partitioned_build_blocks[partition_index];
@@ -183,10 +283,53 @@ Status PartitionedHashJoinSinkOperatorX::init(const TPlanNode& tnode, RuntimeSta
 }
 
 Status PartitionedHashJoinSinkOperatorX::prepare(RuntimeState* state) {
-    return Status::OK();
+    RETURN_IF_ERROR(_inner_sink_operator->set_child(_child_x));
+    return _inner_sink_operator->prepare(state);
 }
 
 Status PartitionedHashJoinSinkOperatorX::open(RuntimeState* state) {
+    return _inner_sink_operator->open(state);
+}
+
+Status 
PartitionedHashJoinSinkOperatorX::_setup_internal_operator(RuntimeState* state) 
{
+    auto& local_state = get_local_state(state);
+
+    local_state._shared_state->inner_runtime_state = 
RuntimeState::create_unique(
+            nullptr, state->fragment_instance_id(), state->query_id(), 
state->fragment_id(),
+            state->query_options(), TQueryGlobals {}, state->exec_env(), 
state->get_query_ctx());
+    local_state._shared_state->inner_runtime_state->set_task_execution_context(
+            state->get_task_execution_context().lock());
+    
local_state._shared_state->inner_runtime_state->set_be_number(state->be_number());
+
+    
local_state._shared_state->inner_runtime_state->set_desc_tbl(&state->desc_tbl());
+    
local_state._shared_state->inner_runtime_state->resize_op_id_to_local_state(-1);
+    
local_state._shared_state->inner_runtime_state->set_pipeline_x_runtime_filter_mgr(
+            state->local_runtime_filter_mgr());
+
+    local_state._shared_state->inner_shared_state = 
std::dynamic_pointer_cast<HashJoinSharedState>(
+            _inner_sink_operator->create_shared_state());
+    LocalSinkStateInfo info {0,  local_state._internal_runtime_profile.get(),
+                             -1, 
local_state._shared_state->inner_shared_state.get(),
+                             {}, {}};
+
+    RETURN_IF_ERROR(_inner_sink_operator->setup_local_state(
+            local_state._shared_state->inner_runtime_state.get(), info));
+    auto* sink_local_state = 
local_state._shared_state->inner_runtime_state->get_sink_local_state();
+    DCHECK(sink_local_state != nullptr);
+
+    LocalStateInfo state_info {local_state._internal_runtime_profile.get(),
+                               {},
+                               
local_state._shared_state->inner_shared_state.get(),
+                               {},
+                               0};
+
+    RETURN_IF_ERROR(_inner_probe_operator->setup_local_state(
+            local_state._shared_state->inner_runtime_state.get(), state_info));
+    auto* probe_local_state = 
local_state._shared_state->inner_runtime_state->get_local_state(
+            _inner_probe_operator->operator_id());
+    DCHECK(probe_local_state != nullptr);
+    RETURN_IF_ERROR(probe_local_state->open(state));
+    RETURN_IF_ERROR(sink_local_state->open(state));
     return Status::OK();
 }
 
@@ -204,44 +347,38 @@ Status PartitionedHashJoinSinkOperatorX::sink(RuntimeState* state, vectorized::B
 
     const auto rows = in_block->rows();
 
-    if (rows > 0) {
-        COUNTER_UPDATE(local_state.rows_input_counter(), 
(int64_t)in_block->rows());
-        /// TODO: DO NOT execute build exprs twice(when partition and building hash table)
-        {
-            SCOPED_TIMER(local_state._partition_timer);
-            RETURN_IF_ERROR(local_state._partitioner->do_partitioning(
-                    state, in_block, local_state._mem_tracker.get()));
-        }
+    const auto need_to_spill = local_state._shared_state->need_to_spill;
+    if (rows == 0) {
+        if (eos) {
+            LOG(INFO) << "hash join sink " << id() << " sink eos, 
set_ready_to_read"
+                      << ", task id: " << state->task_id();
 
-        SCOPED_TIMER(local_state._partition_shuffle_timer);
-        auto* channel_ids =
-                
reinterpret_cast<uint64_t*>(local_state._partitioner->get_channel_ids());
-        std::vector<uint32_t> partition_indexes[_partition_count];
-        for (uint32_t i = 0; i != rows; ++i) {
-            partition_indexes[channel_ids[i]].emplace_back(i);
+            if (!need_to_spill) {
+                if (UNLIKELY(!local_state._shared_state->inner_runtime_state)) 
{
+                    RETURN_IF_ERROR(_setup_internal_operator(state));
+                }
+                RETURN_IF_ERROR(_inner_sink_operator->sink(
+                        local_state._shared_state->inner_runtime_state.get(), 
in_block, eos));
+            }
+            local_state._dependency->set_ready_to_read();
         }
+        return Status::OK();
+    }
 
-        auto& partitioned_blocks = 
local_state._shared_state->partitioned_build_blocks;
-        for (uint32_t i = 0; i != _partition_count; ++i) {
-            const auto count = partition_indexes[i].size();
-            if (UNLIKELY(count == 0)) {
-                continue;
-            }
+    COUNTER_UPDATE(local_state.rows_input_counter(), 
(int64_t)in_block->rows());
+    if (need_to_spill) {
+        RETURN_IF_ERROR(local_state._partition_block(state, in_block, 0, 
rows));
 
-            if (UNLIKELY(!partitioned_blocks[i])) {
-                partitioned_blocks[i] =
-                        
vectorized::MutableBlock::create_unique(in_block->clone_empty());
-            }
-            partitioned_blocks[i]->add_rows(in_block, 
&(partition_indexes[i][0]),
-                                            &(partition_indexes[i][count]));
+        const auto revocable_size = revocable_mem_size(state);
+        if (revocable_size > state->min_revocable_mem()) {
+            return local_state.revoke_memory(state);
         }
-
-        if (local_state._shared_state->need_to_spill) {
-            const auto revocable_size = revocable_mem_size(state);
-            if (revocable_size > state->min_revocable_mem()) {
-                return local_state.revoke_memory(state);
-            }
+    } else {
+        if (UNLIKELY(!local_state._shared_state->inner_runtime_state)) {
+            RETURN_IF_ERROR(_setup_internal_operator(state));
         }
+        RETURN_IF_ERROR(_inner_sink_operator->sink(
+                local_state._shared_state->inner_runtime_state.get(), 
in_block, eos));
     }
 
     if (eos) {
@@ -256,19 +393,7 @@ Status PartitionedHashJoinSinkOperatorX::sink(RuntimeState* state, vectorized::B
 size_t PartitionedHashJoinSinkOperatorX::revocable_mem_size(RuntimeState* 
state) const {
     auto& local_state = get_local_state(state);
     SCOPED_TIMER(local_state.exec_time_counter());
-    auto& partitioned_blocks = 
local_state._shared_state->partitioned_build_blocks;
-
-    size_t mem_size = 0;
-    for (uint32_t i = 0; i != _partition_count; ++i) {
-        auto& block = partitioned_blocks[i];
-        if (block) {
-            auto block_bytes = block->allocated_bytes();
-            if (block_bytes >= 
vectorized::SpillStream::MIN_SPILL_WRITE_BATCH_MEM) {
-                mem_size += block_bytes;
-            }
-        }
-    }
-    return mem_size;
+    return local_state.revocable_mem_size(state);
 }
 
 Status PartitionedHashJoinSinkOperatorX::revoke_memory(RuntimeState* state) {
diff --git a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
index 96e751360d4..3a03c2fc724 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
+++ b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
@@ -48,6 +48,7 @@ public:
     Status open(RuntimeState* state) override;
     Status close(RuntimeState* state, Status exec_status) override;
     Status revoke_memory(RuntimeState* state);
+    size_t revocable_mem_size(RuntimeState* state) const;
 
 protected:
     PartitionedHashJoinSinkLocalState(DataSinkOperatorXBase* parent, 
RuntimeState* state)
@@ -56,6 +57,9 @@ protected:
     void _spill_to_disk(uint32_t partition_index,
                         const vectorized::SpillStreamSPtr& spilling_stream);
 
+    Status _partition_block(RuntimeState* state, vectorized::Block* in_block, 
size_t begin,
+                            size_t end);
+
     friend class PartitionedHashJoinSinkOperatorX;
 
     std::atomic_int _spilling_streams_count {0};
@@ -74,6 +78,8 @@ protected:
 
     std::unique_ptr<PartitionerType> _partitioner;
 
+    std::unique_ptr<RuntimeProfile> _internal_runtime_profile;
+
     RuntimeProfile::Counter* _partition_timer = nullptr;
     RuntimeProfile::Counter* _partition_shuffle_timer = nullptr;
     RuntimeProfile::Counter* _spill_build_timer = nullptr;
@@ -121,13 +127,24 @@ public:
         return _join_distribution == TJoinDistributionType::PARTITIONED;
     }
 
+    void set_inner_operators(const 
std::shared_ptr<HashJoinBuildSinkOperatorX>& sink_operator,
+                             const std::shared_ptr<HashJoinProbeOperatorX>& 
probe_operator) {
+        _inner_sink_operator = sink_operator;
+        _inner_probe_operator = probe_operator;
+    }
+
 private:
     friend class PartitionedHashJoinSinkLocalState;
 
+    Status _setup_internal_operator(RuntimeState* state);
+
     const TJoinDistributionType::type _join_distribution;
 
     std::vector<TExpr> _build_exprs;
 
+    std::shared_ptr<HashJoinBuildSinkOperatorX> _inner_sink_operator;
+    std::shared_ptr<HashJoinProbeOperatorX> _inner_probe_operator;
+
     const std::vector<TExpr> _distribution_partition_exprs;
     const TPlanNode _tnode;
     const DescriptorTbl _descriptor_tbl;
diff --git a/be/src/pipeline/pipeline_x/dependency.h b/be/src/pipeline/pipeline_x/dependency.h
index 3ea096a81d6..f195322e3f5 100644
--- a/be/src/pipeline/pipeline_x/dependency.h
+++ b/be/src/pipeline/pipeline_x/dependency.h
@@ -599,6 +599,8 @@ struct PartitionedHashJoinSharedState
           public std::enable_shared_from_this<PartitionedHashJoinSharedState> {
     ENABLE_FACTORY_CREATOR(PartitionedHashJoinSharedState)
 
+    std::unique_ptr<RuntimeState> inner_runtime_state;
+    std::shared_ptr<HashJoinSharedState> inner_shared_state;
     std::vector<std::unique_ptr<vectorized::MutableBlock>> 
partitioned_build_blocks;
     std::vector<vectorized::SpillStreamSPtr> spilled_streams;
     bool need_to_spill = false;
diff --git a/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp b/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
index f888500cef6..ce2d4a50748 100644
--- a/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
@@ -1040,9 +1040,22 @@ Status PipelineXFragmentContext::_create_operator(ObjectPool* pool, const TPlanN
                                        tnode.hash_join_node.is_broadcast_join;
         const auto enable_join_spill = _runtime_state->enable_join_spill();
         if (enable_join_spill && !is_broadcast_join) {
+            auto tnode_ = tnode;
+            /// TODO: support rf in partitioned hash join
+            tnode_.runtime_filters.clear();
             const uint32_t partition_count = 32;
-            op.reset(new PartitionedHashJoinProbeOperatorX(pool, tnode, 
next_operator_id(), descs,
-                                                           partition_count));
+            auto inner_probe_operator =
+                    std::make_shared<HashJoinProbeOperatorX>(pool, tnode_, 0, 
descs);
+            auto inner_sink_operator = 
std::make_shared<HashJoinBuildSinkOperatorX>(
+                    pool, 0, tnode_, descs, _need_local_merge);
+
+            RETURN_IF_ERROR(inner_probe_operator->init(tnode_, 
_runtime_state.get()));
+            RETURN_IF_ERROR(inner_sink_operator->init(tnode_, 
_runtime_state.get()));
+
+            auto probe_operator = 
std::make_shared<PartitionedHashJoinProbeOperatorX>(
+                    pool, tnode_, next_operator_id(), descs, partition_count);
+            probe_operator->set_inner_operators(inner_sink_operator, 
inner_probe_operator);
+            op = std::move(probe_operator);
             RETURN_IF_ERROR(cur_pipe->add_operator(op));
 
             const auto downstream_pipeline_id = cur_pipe->id();
@@ -1052,13 +1065,14 @@ Status PipelineXFragmentContext::_create_operator(ObjectPool* pool, const TPlanN
             PipelinePtr build_side_pipe = add_pipeline(cur_pipe);
             _dag[downstream_pipeline_id].push_back(build_side_pipe->id());
 
-            DataSinkOperatorXPtr sink;
-            sink.reset(new PartitionedHashJoinSinkOperatorX(pool, 
next_sink_operator_id(), tnode,
-                                                            descs, 
_need_local_merge,
-                                                            partition_count));
+            auto sink_operator = 
std::make_shared<PartitionedHashJoinSinkOperatorX>(
+                    pool, next_sink_operator_id(), tnode_, descs, 
_need_local_merge,
+                    partition_count);
+            sink_operator->set_inner_operators(inner_sink_operator, 
inner_probe_operator);
+            DataSinkOperatorXPtr sink = std::move(sink_operator);
             sink->set_dests_id({op->operator_id()});
             RETURN_IF_ERROR(build_side_pipe->set_sink(sink));
-            RETURN_IF_ERROR(build_side_pipe->sink_x()->init(tnode, 
_runtime_state.get()));
+            RETURN_IF_ERROR(build_side_pipe->sink_x()->init(tnode_, 
_runtime_state.get()));
 
             _pipeline_parent_map.push(op->node_id(), cur_pipe);
             _pipeline_parent_map.push(op->node_id(), build_side_pipe);
diff --git a/be/src/vec/spill/spill_writer.cpp b/be/src/vec/spill/spill_writer.cpp
index dc45752b813..ffd85c852a6 100644
--- a/be/src/vec/spill/spill_writer.cpp
+++ b/be/src/vec/spill/spill_writer.cpp
@@ -107,9 +107,10 @@ Status SpillWriter::_write_internal(const Block& block, size_t& written_bytes) {
         {
             PBlock pblock;
             SCOPED_TIMER(serialize_timer_);
-            status = 
block.serialize(BeExecVersionManager::get_newest_version(), &pblock,
-                                     &uncompressed_bytes, &compressed_bytes,
-                                     segment_v2::CompressionTypePB::LZ4);
+            status = block.serialize(
+                    BeExecVersionManager::get_newest_version(), &pblock, 
&uncompressed_bytes,
+                    &compressed_bytes,
+                    segment_v2::CompressionTypePB::ZSTD); // ZSTD for better 
compression ratio
             RETURN_IF_ERROR(status);
             if (!pblock.SerializeToString(&buff)) {
                 return Status::Error<ErrorCode::SERIALIZE_PROTOBUF_ERROR>(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to