This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 24fce53b34485e688f5252b378d958e8aed57c49
Author: Mryange <[email protected]>
AuthorDate: Mon Aug 26 21:40:58 2024 +0800

    [fix](local shuffle) Not passing channel_id causes a local merge infinite loop (#39725)
---
 be/src/pipeline/dependency.h                       |  4 +--
 be/src/pipeline/local_exchange/local_exchanger.cpp | 30 ++++++++++++----------
 be/src/pipeline/local_exchange/local_exchanger.h   |  9 ++++---
 3 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/be/src/pipeline/dependency.h b/be/src/pipeline/dependency.h
index 8def7be6147..70a4b7d0b47 100644
--- a/be/src/pipeline/dependency.h
+++ b/be/src/pipeline/dependency.h
@@ -856,13 +856,13 @@ public:
 
     void sub_mem_usage(int channel_id, size_t delta) { 
mem_trackers[channel_id]->release(delta); }
 
-    virtual void add_total_mem_usage(size_t delta, int channel_id = 0) {
+    virtual void add_total_mem_usage(size_t delta, int channel_id) {
         if (mem_usage.fetch_add(delta) + delta > 
config::local_exchange_buffer_mem_limit) {
             sink_deps.front()->block();
         }
     }
 
-    virtual void sub_total_mem_usage(size_t delta, int channel_id = 0) {
+    virtual void sub_total_mem_usage(size_t delta, int channel_id) {
         auto prev_usage = mem_usage.fetch_sub(delta);
         DCHECK_GE(prev_usage - delta, 0) << "prev_usage: " << prev_usage << " 
delta: " << delta
                                          << " channel_id: " << channel_id;
diff --git a/be/src/pipeline/local_exchange/local_exchanger.cpp b/be/src/pipeline/local_exchange/local_exchanger.cpp
index 39e13802798..1bcd9f34ba8 100644
--- a/be/src/pipeline/local_exchange/local_exchanger.cpp
+++ b/be/src/pipeline/local_exchange/local_exchanger.cpp
@@ -54,9 +54,9 @@ void Exchanger<BlockType>::_enqueue_data_and_set_ready(int channel_id,
         // just unref the block.
         if constexpr (std::is_same_v<PartitionedBlock, BlockType> ||
                       std::is_same_v<BroadcastBlock, BlockType>) {
-            block.first->unref(local_state._shared_state, allocated_bytes);
+            block.first->unref(local_state._shared_state, allocated_bytes, 
channel_id);
         } else {
-            block->unref(local_state._shared_state, allocated_bytes);
+            block->unref(local_state._shared_state, allocated_bytes, 
channel_id);
             DCHECK_EQ(block->ref_value(), 0);
         }
     }
@@ -83,7 +83,7 @@ bool Exchanger<BlockType>::_dequeue_data(LocalExchangeSourceLocalState& local_st
             local_state._shared_state->sub_mem_usage(channel_id,
                                                      
block->data_block.allocated_bytes());
             data_block->swap(block->data_block);
-            block->unref(local_state._shared_state, 
data_block->allocated_bytes());
+            block->unref(local_state._shared_state, 
data_block->allocated_bytes(), channel_id);
             DCHECK_EQ(block->ref_value(), 0);
         }
         return true;
@@ -100,7 +100,7 @@ bool Exchanger<BlockType>::_dequeue_data(LocalExchangeSourceLocalState& local_st
                 local_state._shared_state->sub_mem_usage(channel_id,
                                                          
block->data_block.allocated_bytes());
                 data_block->swap(block->data_block);
-                block->unref(local_state._shared_state, 
data_block->allocated_bytes());
+                block->unref(local_state._shared_state, 
data_block->allocated_bytes(), channel_id);
                 DCHECK_EQ(block->ref_value(), 0);
             }
             return true;
@@ -137,7 +137,7 @@ void ShuffleExchanger::close(LocalExchangeSourceLocalState& local_state) {
     vectorized::Block block;
     _data_queue[local_state._channel_id].set_eos();
     while (_dequeue_data(local_state, partitioned_block, &eos, &block)) {
-        partitioned_block.first->unref(local_state._shared_state);
+        partitioned_block.first->unref(local_state._shared_state, 
local_state._channel_id);
     }
 }
 
@@ -153,7 +153,7 @@ Status ShuffleExchanger::get_block(RuntimeState* state, vectorized::Block* block
             auto block_wrapper = partitioned_block.first;
             RETURN_IF_ERROR(mutable_block.add_rows(&block_wrapper->data_block, 
offset_start,
                                                    offset_start + 
partitioned_block.second.length));
-            block_wrapper->unref(local_state._shared_state);
+            block_wrapper->unref(local_state._shared_state, 
local_state._channel_id);
         } while (mutable_block.rows() < state->batch_size() && !*eos &&
                  _dequeue_data(local_state, partitioned_block, eos, block));
         return Status::OK();
@@ -200,7 +200,8 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest
     if (new_block_wrapper->data_block.empty()) {
         return Status::OK();
     }
-    
local_state._shared_state->add_total_mem_usage(new_block_wrapper->data_block.allocated_bytes());
+    
local_state._shared_state->add_total_mem_usage(new_block_wrapper->data_block.allocated_bytes(),
+                                                   local_state._channel_id);
     auto bucket_seq_to_instance_idx =
             
local_state._parent->cast<LocalExchangeSinkOperatorX>()._bucket_seq_to_instance_idx;
     if (get_type() == ExchangeType::HASH_SHUFFLE) {
@@ -222,7 +223,7 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest
                 _enqueue_data_and_set_ready(it.second, local_state,
                                             {new_block_wrapper, {row_idx, 
start, size}});
             } else {
-                new_block_wrapper->unref(local_state._shared_state);
+                new_block_wrapper->unref(local_state._shared_state, 
local_state._channel_id);
             }
         }
     } else if (_num_senders != _num_sources || 
_ignore_source_data_distribution) {
@@ -235,7 +236,7 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest
                 _enqueue_data_and_set_ready(i % _num_sources, local_state,
                                             {new_block_wrapper, {row_idx, 
start, size}});
             } else {
-                new_block_wrapper->unref(local_state._shared_state);
+                new_block_wrapper->unref(local_state._shared_state, 
local_state._channel_id);
             }
         }
     } else if (bucket_seq_to_instance_idx.empty()) {
@@ -256,7 +257,7 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest
                 _enqueue_data_and_set_ready(it.second, local_state,
                                             {new_block_wrapper, {row_idx, 
start, size}});
             } else {
-                new_block_wrapper->unref(local_state._shared_state);
+                new_block_wrapper->unref(local_state._shared_state, 
local_state._channel_id);
             }
         }
     } else {
@@ -268,7 +269,7 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest
                 _enqueue_data_and_set_ready(bucket_seq_to_instance_idx[i], 
local_state,
                                             {new_block_wrapper, {row_idx, 
start, size}});
             } else {
-                new_block_wrapper->unref(local_state._shared_state);
+                new_block_wrapper->unref(local_state._shared_state, 
local_state._channel_id);
             }
         }
     }
@@ -443,7 +444,8 @@ Status BroadcastExchanger::sink(RuntimeState* state, vectorized::Block* in_block
     }
     new_block.swap(*in_block);
     auto wrapper = BlockWrapper::create_shared(std::move(new_block));
-    
local_state._shared_state->add_total_mem_usage(wrapper->data_block.allocated_bytes());
+    
local_state._shared_state->add_total_mem_usage(wrapper->data_block.allocated_bytes(),
+                                                   local_state._channel_id);
     wrapper->ref(_num_partitions);
     for (size_t i = 0; i < _num_partitions; i++) {
         _enqueue_data_and_set_ready(i, local_state, {wrapper, {0, 
wrapper->data_block.rows()}});
@@ -458,7 +460,7 @@ void BroadcastExchanger::close(LocalExchangeSourceLocalState& local_state) {
     vectorized::Block block;
     _data_queue[local_state._channel_id].set_eos();
     while (_dequeue_data(local_state, partitioned_block, &eos, &block)) {
-        partitioned_block.first->unref(local_state._shared_state);
+        partitioned_block.first->unref(local_state._shared_state, 
local_state._channel_id);
     }
 }
 
@@ -475,7 +477,7 @@ Status BroadcastExchanger::get_block(RuntimeState* state, vectorized::Block* blo
         RETURN_IF_ERROR(mutable_block.add_rows(&block_wrapper->data_block,
                                                
partitioned_block.second.offset_start,
                                                
partitioned_block.second.length));
-        block_wrapper->unref(local_state._shared_state);
+        block_wrapper->unref(local_state._shared_state, 
local_state._channel_id);
     }
 
     return Status::OK();
diff --git a/be/src/pipeline/local_exchange/local_exchanger.h b/be/src/pipeline/local_exchange/local_exchanger.h
index fe978a3cbdc..01b55816ba8 100644
--- a/be/src/pipeline/local_exchange/local_exchanger.h
+++ b/be/src/pipeline/local_exchange/local_exchanger.h
@@ -176,10 +176,10 @@ struct BlockWrapper {
     BlockWrapper(vectorized::Block&& data_block_) : 
data_block(std::move(data_block_)) {}
     ~BlockWrapper() { DCHECK_EQ(ref_count.load(), 0); }
     void ref(int delta) { ref_count += delta; }
-    void unref(LocalExchangeSharedState* shared_state, size_t allocated_bytes) 
{
+    void unref(LocalExchangeSharedState* shared_state, size_t allocated_bytes, 
int channel_id) {
         if (ref_count.fetch_sub(1) == 1) {
             DCHECK_GT(allocated_bytes, 0);
-            shared_state->sub_total_mem_usage(allocated_bytes);
+            shared_state->sub_total_mem_usage(allocated_bytes, channel_id);
             if (shared_state->exchanger->_free_block_limit == 0 ||
                 shared_state->exchanger->_free_blocks.size_approx() <
                         shared_state->exchanger->_free_block_limit *
@@ -189,8 +189,9 @@ struct BlockWrapper {
             }
         }
     }
-    void unref(LocalExchangeSharedState* shared_state) {
-        unref(shared_state, data_block.allocated_bytes());
+
+    void unref(LocalExchangeSharedState* shared_state, int channel_id) {
+        unref(shared_state, data_block.allocated_bytes(), channel_id);
     }
     int ref_value() const { return ref_count.load(); }
     std::atomic<int> ref_count = 0;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to