This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit d117ed7191ffb4754b594a486e99987f4fab7f65
Author: Gabriel <[email protected]>
AuthorDate: Thu Aug 29 16:51:45 2024 +0800

    [fix](local shuffle) Fix hash shuffle local exchanger (#40036)
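    
    This patch tracks, while the operator tree is built, whether each operator is
    followed by a shuffled hash join (the new require_shuffled_data_distribution()
    and set_followed_by_shuffled_join() hooks). When a fragment has no
    bucket-to-instance mapping and the downstream join requires a shuffled
    distribution, a planned BUCKET_HASH_SHUFFLE local exchange is downgraded to
    HASH_SHUFFLE so that both sides of the join see the same data distribution.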
---
 be/src/pipeline/exec/aggregation_sink_operator.h   |  3 +-
 be/src/pipeline/exec/analytic_sink_operator.h      |  5 +-
 .../exec/distinct_streaming_aggregation_operator.h |  5 +-
 be/src/pipeline/exec/hashjoin_build_sink.h         |  3 ++
 be/src/pipeline/exec/hashjoin_probe_operator.h     |  3 ++
 be/src/pipeline/exec/operator.h                    |  6 +++
 .../exec/partitioned_aggregation_sink_operator.h   |  3 ++
 .../exec/partitioned_hash_join_probe_operator.h    |  3 ++
 .../exec/partitioned_hash_join_sink_operator.h     |  3 ++
 be/src/pipeline/exec/set_probe_sink_operator.h     |  2 +
 be/src/pipeline/exec/set_sink_operator.h           |  1 +
 be/src/pipeline/exec/sort_sink_operator.h          |  3 +-
 .../local_exchange_sink_operator.cpp               | 14 ++---
 .../local_exchange/local_exchange_sink_operator.h  |  2 +-
 be/src/pipeline/local_exchange/local_exchanger.cpp | 22 +-------
 be/src/pipeline/pipeline_fragment_context.cpp      | 63 ++++++++++++++++++----
 be/src/pipeline/pipeline_fragment_context.h        |  5 +-
 .../nereids_p0/join/test_join_local_shuffle.groovy |  6 ++-
 18 files changed, 106 insertions(+), 46 deletions(-)

diff --git a/be/src/pipeline/exec/aggregation_sink_operator.h b/be/src/pipeline/exec/aggregation_sink_operator.h
index 579b9eda1a6..f7b225311a3 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.h
+++ b/be/src/pipeline/exec/aggregation_sink_operator.h
@@ -149,11 +149,12 @@ public:
                            ? DataDistribution(ExchangeType::PASSTHROUGH)
                            : DataSinkOperatorX<AggSinkLocalState>::required_data_distribution();
        }
-        return _is_colocate && _require_bucket_distribution
+        return _is_colocate && _require_bucket_distribution && !_followed_by_shuffled_join
                       ? DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE, _partition_exprs)
                       : DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs);
    }
    bool require_data_distribution() const override { return _is_colocate; }
+    bool require_shuffled_data_distribution() const override { return !_probe_expr_ctxs.empty(); }
     size_t get_revocable_mem_size(RuntimeState* state) const;
 
     AggregatedDataVariants* get_agg_data(RuntimeState* state) {
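
Editor's note: the `!_followed_by_shuffled_join` guard above recurs in several sinks below (analytic, distinct streaming agg, sort). A standalone restatement with assumed names (not the Doris API): keep the cheaper colocated bucket shuffle only when the operator does not feed a shuffled hash join downstream, otherwise fall back to a plain hash shuffle so both join sides distribute identically.

    enum class ExchangeType { HASH_SHUFFLE, BUCKET_HASH_SHUFFLE };

    // Hypothetical free function condensing the repeated ternary.
    ExchangeType required_distribution(bool is_colocate, bool require_bucket_distribution,
                                       bool followed_by_shuffled_join) {
        return is_colocate && require_bucket_distribution && !followed_by_shuffled_join
                       ? ExchangeType::BUCKET_HASH_SHUFFLE
                       : ExchangeType::HASH_SHUFFLE;
    }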
diff --git a/be/src/pipeline/exec/analytic_sink_operator.h b/be/src/pipeline/exec/analytic_sink_operator.h
index 47080b82380..6d713996b9c 100644
--- a/be/src/pipeline/exec/analytic_sink_operator.h
+++ b/be/src/pipeline/exec/analytic_sink_operator.h
@@ -82,7 +82,7 @@ public:
         if (_partition_by_eq_expr_ctxs.empty()) {
             return {ExchangeType::PASSTHROUGH};
         } else if (_order_by_eq_expr_ctxs.empty()) {
-            return _is_colocate && _require_bucket_distribution
+            return _is_colocate && _require_bucket_distribution && !_followed_by_shuffled_join
                            ? DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE, _partition_exprs)
                            : DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs);
         }
@@ -90,6 +90,9 @@ public:
     }
 
     bool require_data_distribution() const override { return true; }
+    bool require_shuffled_data_distribution() const override {
+        return !_partition_by_eq_expr_ctxs.empty() && _order_by_eq_expr_ctxs.empty();
+    }
 
 private:
     Status _insert_range_column(vectorized::Block* block, const vectorized::VExprContextSPtr& expr,
diff --git a/be/src/pipeline/exec/distinct_streaming_aggregation_operator.h b/be/src/pipeline/exec/distinct_streaming_aggregation_operator.h
index d6ff5fde0c5..8ec1d18fd9e 100644
--- a/be/src/pipeline/exec/distinct_streaming_aggregation_operator.h
+++ b/be/src/pipeline/exec/distinct_streaming_aggregation_operator.h
@@ -106,7 +106,7 @@ public:
 
     DataDistribution required_data_distribution() const override {
         if (_needs_finalize || (!_probe_expr_ctxs.empty() && !_is_streaming_preagg)) {
-            return _is_colocate && _require_bucket_distribution
+            return _is_colocate && _require_bucket_distribution && !_followed_by_shuffled_join
                            ? DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE, _partition_exprs)
                            : DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs);
         }
@@ -114,6 +114,9 @@ public:
     }
 
     bool require_data_distribution() const override { return _is_colocate; }
+    bool require_shuffled_data_distribution() const override {
+        return _needs_finalize || (!_probe_expr_ctxs.empty() && !_is_streaming_preagg);
+    }
 
 private:
     friend class DistinctStreamingAggLocalState;
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h b/be/src/pipeline/exec/hashjoin_build_sink.h
index f6681548dbe..4dba04559b4 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.h
+++ b/be/src/pipeline/exec/hashjoin_build_sink.h
@@ -143,6 +143,9 @@ public:
                        : DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs);
     }
 
+    bool require_shuffled_data_distribution() const override {
+        return _join_op != TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN && !_is_broadcast_join;
+    }
     bool is_shuffled_hash_join() const override {
         return _join_distribution == TJoinDistributionType::PARTITIONED;
     }
diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.h b/be/src/pipeline/exec/hashjoin_probe_operator.h
index c1a53c0c1f3..69ab0808be4 100644
--- a/be/src/pipeline/exec/hashjoin_probe_operator.h
+++ b/be/src/pipeline/exec/hashjoin_probe_operator.h
@@ -153,6 +153,9 @@ public:
                                   : DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs));
     }
 
+    bool require_shuffled_data_distribution() const override {
+        return _join_op != TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN && !_is_broadcast_join;
+    }
     bool is_shuffled_hash_join() const override {
         return _join_distribution == TJoinDistributionType::PARTITIONED;
     }
diff --git a/be/src/pipeline/exec/operator.h b/be/src/pipeline/exec/operator.h
index 9d549690461..abed7fb446a 100644
--- a/be/src/pipeline/exec/operator.h
+++ b/be/src/pipeline/exec/operator.h
@@ -114,11 +114,17 @@ public:
     virtual Status revoke_memory(RuntimeState* state) { return Status::OK(); }
     [[nodiscard]] virtual bool require_data_distribution() const { return false; }
     OperatorXPtr child_x() { return _child_x; }
+    [[nodiscard]] bool followed_by_shuffled_join() const { return _followed_by_shuffled_join; }
+    void set_followed_by_shuffled_join(bool followed_by_shuffled_join) {
+        _followed_by_shuffled_join = followed_by_shuffled_join;
+    }
+    [[nodiscard]] virtual bool require_shuffled_data_distribution() const { return false; }
 
 protected:
     OperatorXPtr _child_x = nullptr;
 
     bool _is_closed;
+    bool _followed_by_shuffled_join = false;
 };
 
 class PipelineXLocalStateBase {
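
Editor's note: a minimal sketch of how the hooks added above are meant to be used. The stand-in base class below mirrors only the three members introduced in this hunk; the name OperatorBase and the subclass are assumptions for illustration, not the real Doris classes.

    #include <iostream>

    // Stand-in for the operator base class, mirroring the new interface.
    class OperatorBase {
    public:
        virtual ~OperatorBase() = default;
        virtual bool require_shuffled_data_distribution() const { return false; }
        bool followed_by_shuffled_join() const { return _followed_by_shuffled_join; }
        void set_followed_by_shuffled_join(bool v) { _followed_by_shuffled_join = v; }

    protected:
        bool _followed_by_shuffled_join = false;
    };

    // An aggregation-like operator that always needs shuffled input.
    class ShuffledAggSink : public OperatorBase {
    public:
        bool require_shuffled_data_distribution() const override { return true; }
    };

    int main() {
        ShuffledAggSink sink;
        // Planner-side code can now query and tag operators uniformly.
        sink.set_followed_by_shuffled_join(sink.require_shuffled_data_distribution());
        std::cout << std::boolalpha << sink.followed_by_shuffled_join() << '\n'; // true
    }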
diff --git a/be/src/pipeline/exec/partitioned_aggregation_sink_operator.h b/be/src/pipeline/exec/partitioned_aggregation_sink_operator.h
index e5e44498ec0..9282df073cb 100644
--- a/be/src/pipeline/exec/partitioned_aggregation_sink_operator.h
+++ b/be/src/pipeline/exec/partitioned_aggregation_sink_operator.h
@@ -312,6 +312,9 @@ public:
     bool require_data_distribution() const override {
         return _agg_sink_operator->require_data_distribution();
     }
+    bool require_shuffled_data_distribution() const override {
+        return _agg_sink_operator->require_shuffled_data_distribution();
+    }
 
     Status set_child(OperatorXPtr child) override {
         RETURN_IF_ERROR(DataSinkOperatorX<PartitionedAggSinkLocalState>::set_child(child));
diff --git a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.h b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.h
index 6ee718a3354..a63ddb3e69d 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.h
+++ b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.h
@@ -166,6 +166,9 @@ public:
                                            _distribution_partition_exprs));
     }
 
+    bool require_shuffled_data_distribution() const override {
+        return _join_op != TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN;
+    }
     bool is_shuffled_hash_join() const override {
         return _join_distribution == TJoinDistributionType::PARTITIONED;
     }
diff --git a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
index 1592c29cdb0..252c53be12d 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
+++ b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
@@ -116,6 +116,9 @@ public:
                                           _distribution_partition_exprs);
     }
 
+    bool require_shuffled_data_distribution() const override {
+        return _join_op != TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN;
+    }
     bool is_shuffled_hash_join() const override {
         return _join_distribution == TJoinDistributionType::PARTITIONED;
     }
diff --git a/be/src/pipeline/exec/set_probe_sink_operator.h b/be/src/pipeline/exec/set_probe_sink_operator.h
index 93a862fa1cb..f21d5842581 100644
--- a/be/src/pipeline/exec/set_probe_sink_operator.h
+++ b/be/src/pipeline/exec/set_probe_sink_operator.h
@@ -98,6 +98,8 @@ public:
                             : DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs);
     }
 
+    bool require_shuffled_data_distribution() const override { return true; }
+
     std::shared_ptr<BasicSharedState> create_shared_state() const override { return nullptr; }
 
 private:
diff --git a/be/src/pipeline/exec/set_sink_operator.h b/be/src/pipeline/exec/set_sink_operator.h
index 09a1fa09e7c..ac0757e4467 100644
--- a/be/src/pipeline/exec/set_sink_operator.h
+++ b/be/src/pipeline/exec/set_sink_operator.h
@@ -95,6 +95,7 @@ public:
         return _is_colocate ? DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE, _partition_exprs)
                             : DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs);
     }
+    bool require_shuffled_data_distribution() const override { return true; }
 
 private:
     template <class HashTableContext, bool is_intersected>
diff --git a/be/src/pipeline/exec/sort_sink_operator.h b/be/src/pipeline/exec/sort_sink_operator.h
index b842a56f2ad..3188bfe3990 100644
--- a/be/src/pipeline/exec/sort_sink_operator.h
+++ b/be/src/pipeline/exec/sort_sink_operator.h
@@ -64,7 +64,7 @@ public:
     Status sink(RuntimeState* state, vectorized::Block* in_block, bool eos) override;
     DataDistribution required_data_distribution() const override {
         if (_is_analytic_sort) {
-            return _is_colocate && _require_bucket_distribution
+            return _is_colocate && _require_bucket_distribution && !_followed_by_shuffled_join
                            ? DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE, _partition_exprs)
                            : DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs);
         } else if (_merge_by_exchange) {
@@ -73,6 +73,7 @@ public:
         }
         return DataSinkOperatorX<SortSinkLocalState>::required_data_distribution();
     }
+    bool require_shuffled_data_distribution() const override { return _is_analytic_sort; }
     bool require_data_distribution() const override { return _is_colocate; }
 
     size_t get_revocable_mem_size(RuntimeState* state) const;
diff --git a/be/src/pipeline/local_exchange/local_exchange_sink_operator.cpp b/be/src/pipeline/local_exchange/local_exchange_sink_operator.cpp
index f0a51696075..98b1a719a49 100644
--- a/be/src/pipeline/local_exchange/local_exchange_sink_operator.cpp
+++ b/be/src/pipeline/local_exchange/local_exchange_sink_operator.cpp
@@ -36,16 +36,16 @@ std::vector<Dependency*> LocalExchangeSinkLocalState::dependencies() const {
 }
 
 Status LocalExchangeSinkOperatorX::init(ExchangeType type, const int num_buckets,
-                                        const bool is_shuffled_hash_join,
+                                        const bool should_disable_bucket_shuffle,
                                         const std::map<int, int>& shuffle_idx_to_instance_idx) {
     _name = "LOCAL_EXCHANGE_SINK_OPERATOR (" + get_exchange_type_name(type) + ")";
     _type = type;
-    if (_type == ExchangeType::HASH_SHUFFLE || _type == ExchangeType::BUCKET_HASH_SHUFFLE) {
+    if (_type == ExchangeType::HASH_SHUFFLE) {
        // For shuffle join, if data distribution has been broken by previous operator, we
        // should use a HASH_SHUFFLE local exchanger to shuffle data again. To be mentioned,
        // we should use map shuffle idx to instance idx because all instances will be
        // distributed to all BEs. Otherwise, we should use shuffle idx directly.
-        if (is_shuffled_hash_join) {
+        if (should_disable_bucket_shuffle) {
            std::for_each(shuffle_idx_to_instance_idx.begin(), shuffle_idx_to_instance_idx.end(),
                          [&](const auto& item) {
                              DCHECK(item.first != -1);
@@ -58,9 +58,11 @@ Status LocalExchangeSinkOperatorX::init(ExchangeType type, const int num_buckets
             }
         }
         _partitioner.reset(new vectorized::Crc32HashPartitioner<vectorized::ShuffleChannelIds>(
-                _type == ExchangeType::HASH_SHUFFLE || _bucket_seq_to_instance_idx.empty()
-                        ? _num_partitions
-                        : num_buckets));
+                _num_partitions));
+        RETURN_IF_ERROR(_partitioner->init(_texprs));
+    } else if (_type == ExchangeType::BUCKET_HASH_SHUFFLE) {
+        _partitioner.reset(
+                new vectorized::Crc32HashPartitioner<vectorized::ShuffleChannelIds>(num_buckets));
         RETURN_IF_ERROR(_partitioner->init(_texprs));
     }
     return Status::OK();
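
Editor's note: the shuffle-idx remapping described in the comment above, as a standalone toy (simplified types and names, not the Doris implementation): rows hash to a global shuffle idx, and when all instances are distributed across all BEs that idx is translated through shuffle_idx_to_instance_idx; otherwise the shuffle idx is used directly.

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <vector>

    // Toy channel picker: hash -> global shuffle idx -> (optional) instance idx.
    std::vector<int> pick_channels(const std::vector<uint32_t>& row_hashes, int num_partitions,
                                   const std::map<int, int>& shuffle_idx_to_instance_idx) {
        std::vector<int> channels;
        channels.reserve(row_hashes.size());
        for (uint32_t h : row_hashes) {
            int shuffle_idx = static_cast<int>(h % num_partitions);
            auto it = shuffle_idx_to_instance_idx.find(shuffle_idx);
            // Remap when a mapping exists; otherwise use the shuffle idx directly.
            channels.push_back(it == shuffle_idx_to_instance_idx.end() ? shuffle_idx : it->second);
        }
        return channels;
    }

    int main() {
        // 4 global partitions; global idx 2 and 3 map back to local instances 0 and 1.
        std::map<int, int> shuffle_idx_to_instance_idx = {{0, 0}, {1, 1}, {2, 0}, {3, 1}};
        for (int ch : pick_channels({7, 12, 33, 2}, 4, shuffle_idx_to_instance_idx)) {
            std::cout << ch << ' '; // 3->1, 0->0, 1->1, 2->0
        }
        std::cout << '\n';
    }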
diff --git a/be/src/pipeline/local_exchange/local_exchange_sink_operator.h b/be/src/pipeline/local_exchange/local_exchange_sink_operator.h
index faa48d209f4..e0e7688307c 100644
--- a/be/src/pipeline/local_exchange/local_exchange_sink_operator.h
+++ b/be/src/pipeline/local_exchange/local_exchange_sink_operator.h
@@ -102,7 +102,7 @@ public:
         return Status::InternalError("{} should not init with TPlanNode", Base::_name);
     }
 
-    Status init(ExchangeType type, const int num_buckets, const bool is_shuffled_hash_join,
+    Status init(ExchangeType type, const int num_buckets, const bool should_disable_bucket_shuffle,
                 const std::map<int, int>& shuffle_idx_to_instance_idx) override;
 
     Status prepare(RuntimeState* state) override;
diff --git a/be/src/pipeline/local_exchange/local_exchanger.cpp b/be/src/pipeline/local_exchange/local_exchanger.cpp
index 1bcd9f34ba8..f4630f328bb 100644
--- a/be/src/pipeline/local_exchange/local_exchanger.cpp
+++ b/be/src/pipeline/local_exchange/local_exchanger.cpp
@@ -239,28 +239,8 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest
                 new_block_wrapper->unref(local_state._shared_state, local_state._channel_id);
             }
         }
-    } else if (bucket_seq_to_instance_idx.empty()) {
-        /**
-         * If type is `BUCKET_HASH_SHUFFLE` and `_bucket_seq_to_instance_idx` is empty, which
-         * means no scan operators is included in this fragment so we also need a `HASH_SHUFFLE` here.
-         */
-        const auto& map = local_state._parent->cast<LocalExchangeSinkOperatorX>()
-                                  ._shuffle_idx_to_instance_idx;
-        DCHECK(!map.empty());
-        new_block_wrapper->ref(map.size());
-        for (const auto& it : map) {
-            DCHECK(it.second >= 0 && it.second < _num_partitions)
-                    << it.first << " : " << it.second << " " << _num_partitions;
-            uint32_t start = local_state._partition_rows_histogram[it.first];
-            uint32_t size = local_state._partition_rows_histogram[it.first + 1] - start;
-            if (size > 0) {
-                _enqueue_data_and_set_ready(it.second, local_state,
-                                            {new_block_wrapper, {row_idx, start, size}});
-            } else {
-                new_block_wrapper->unref(local_state._shared_state, local_state._channel_id);
-            }
-        }
     } else {
+        DCHECK(!bucket_seq_to_instance_idx.empty());
         new_block_wrapper->ref(_num_partitions);
         for (size_t i = 0; i < _num_partitions; i++) {
             uint32_t start = local_state._partition_rows_histogram[i];
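
Editor's note: with the fallback branch deleted, the exchanger may rely on a non-empty bucket map, because fragments without scan operators are now rewritten to HASH_SHUFFLE at plan time (see pipeline_fragment_context.cpp below). A toy illustration of the invariant the new DCHECK encodes (simplified types, not the Doris implementation):

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // bucket_seq_to_instance_idx[b] = instance that owns bucket b on this backend.
    int bucket_to_channel(uint32_t bucket_seq, const std::vector<int>& bucket_seq_to_instance_idx) {
        // Invariant now guaranteed at plan time: the empty-map case never reaches
        // a BUCKET_HASH_SHUFFLE exchanger.
        assert(!bucket_seq_to_instance_idx.empty());
        return bucket_seq_to_instance_idx[bucket_seq % bucket_seq_to_instance_idx.size()];
    }

    int main() {
        std::vector<int> bucket_seq_to_instance_idx = {0, 1, 0, 1};
        std::cout << bucket_to_channel(3, bucket_seq_to_instance_idx) << '\n'; // 1
    }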
diff --git a/be/src/pipeline/pipeline_fragment_context.cpp b/be/src/pipeline/pipeline_fragment_context.cpp
index a7c61c3a184..2a89f9be7b6 100644
--- a/be/src/pipeline/pipeline_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_fragment_context.cpp
@@ -616,7 +616,7 @@ Status PipelineFragmentContext::_build_pipelines(ObjectPool* pool,
     int node_idx = 0;
 
     RETURN_IF_ERROR(_create_tree_helper(pool, request.fragment.plan.nodes, request, descs, nullptr,
-                                        &node_idx, root, cur_pipe, 0));
+                                        &node_idx, root, cur_pipe, 0, false));
 
     if (node_idx + 1 != request.fragment.plan.nodes.size()) {
         return Status::InternalError(
@@ -630,7 +630,8 @@ Status PipelineFragmentContext::_create_tree_helper(ObjectPool* pool,
                                                     const doris::TPipelineFragmentParams& request,
                                                     const DescriptorTbl& descs, OperatorXPtr parent,
                                                     int* node_idx, OperatorXPtr* root,
-                                                    PipelinePtr& cur_pipe, int child_idx) {
+                                                    PipelinePtr& cur_pipe, int child_idx,
+                                                    const bool followed_by_shuffled_join) {
     // propagate error case
     if (*node_idx >= tnodes.size()) {
         return Status::InternalError(
@@ -640,9 +641,11 @@ Status PipelineFragmentContext::_create_tree_helper(ObjectPool* pool,
     const TPlanNode& tnode = tnodes[*node_idx];
 
     int num_children = tnodes[*node_idx].num_children;
+    bool current_followed_by_shuffled_join = followed_by_shuffled_join;
     OperatorXPtr op = nullptr;
     RETURN_IF_ERROR(_create_operator(pool, tnodes[*node_idx], request, descs, op, cur_pipe,
-                                     parent == nullptr ? -1 : parent->node_id(), child_idx));
+                                     parent == nullptr ? -1 : parent->node_id(), child_idx,
+                                     followed_by_shuffled_join));
 
     // assert(parent != nullptr || (node_idx == 0 && root_expr != nullptr));
     if (parent != nullptr) {
@@ -651,12 +654,30 @@ Status PipelineFragmentContext::_create_tree_helper(ObjectPool* pool,
     } else {
         *root = op;
     }
+    /**
+     * `ExchangeType::HASH_SHUFFLE` should be used if an operator is followed by a shuffled hash join.
+     *
+     * For plan:
+     * LocalExchange(id=0) -> Aggregation(id=1) -> ShuffledHashJoin(id=2)
+     *                           Exchange(id=3) -> ShuffledHashJoinBuild(id=2)
+     * We must ensure the data distribution of `LocalExchange(id=0)` is the same as that of
+     * `Exchange(id=3)`.
+     *
+     * If an operator is followed by a local exchange without a shuffle (e.g. passthrough), a
+     * shuffled local exchanger will be inserted before the join, so the operator is not treated
+     * as being followed by a shuffled join.
+     */
+    auto require_shuffled_data_distribution =
+            cur_pipe->operator_xs().empty()
+                    ? cur_pipe->sink_x()->require_shuffled_data_distribution()
+                    : op->require_shuffled_data_distribution();
+    current_followed_by_shuffled_join =
+            (followed_by_shuffled_join || op->is_shuffled_hash_join()) &&
+            require_shuffled_data_distribution;
 
     // rely on that tnodes is preorder of the plan
     for (int i = 0; i < num_children; i++) {
         ++*node_idx;
         RETURN_IF_ERROR(_create_tree_helper(pool, tnodes, request, descs, op, node_idx, nullptr,
-                                            cur_pipe, i));
+                                            cur_pipe, i, current_followed_by_shuffled_join));
 
         // we are expecting a child, but have used all nodes
         // this means we have been given a bad tree and must fail
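
Editor's note: the propagation rule above, modeled as a toy preorder walk (assumed types, not the Doris API). A node carries the flag when its ancestor chain contains a shuffled hash join and every operator in between requires a shuffled data distribution.

    #include <memory>
    #include <vector>

    struct Node {
        bool is_shuffled_hash_join = false;
        bool require_shuffled_data_distribution = false;
        bool followed_by_shuffled_join = false;
        std::vector<std::unique_ptr<Node>> children;
    };

    void mark(Node& node, bool followed_by_shuffled_join) {
        node.followed_by_shuffled_join = followed_by_shuffled_join;
        // Mirrors: (incoming flag || node is a shuffled join) && node requires shuffle.
        const bool child_flag = (followed_by_shuffled_join || node.is_shuffled_hash_join) &&
                                node.require_shuffled_data_distribution;
        for (auto& child : node.children) {
            mark(*child, child_flag);
        }
    }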
@@ -693,15 +714,30 @@ Status PipelineFragmentContext::_add_local_exchange_impl(
     // 1. Create a new pipeline with local exchange sink.
     DataSinkOperatorXPtr sink;
     auto sink_id = next_sink_operator_id();
-    const bool is_shuffled_hash_join = operator_xs.size() > idx
-                                               ? operator_xs[idx]->is_shuffled_hash_join()
-                                               : cur_pipe->sink_x()->is_shuffled_hash_join();
+
+    /**
+     * `bucket_seq_to_instance_idx` is empty if no scan operator is contained in this fragment,
+     * so co-located operators (e.g. Agg, Analytic) should use `HASH_SHUFFLE` instead of
+     * `BUCKET_HASH_SHUFFLE`.
+     */
+    const bool followed_by_shuffled_join =
+            operator_xs.size() > idx ? operator_xs[idx]->followed_by_shuffled_join()
+                                     : cur_pipe->sink_x()->followed_by_shuffled_join();
+    const bool should_disable_bucket_shuffle =
+            bucket_seq_to_instance_idx.empty() &&
+            shuffle_idx_to_instance_idx.find(-1) == shuffle_idx_to_instance_idx.end() &&
+            followed_by_shuffled_join;
     sink.reset(new LocalExchangeSinkOperatorX(
-            sink_id, local_exchange_id, is_shuffled_hash_join ? _total_instances : _num_instances,
+            sink_id, local_exchange_id,
+            should_disable_bucket_shuffle ? _total_instances : _num_instances,
             data_distribution.partition_exprs, bucket_seq_to_instance_idx));
+    if (should_disable_bucket_shuffle &&
+        data_distribution.distribution_type == ExchangeType::BUCKET_HASH_SHUFFLE) {
+        data_distribution.distribution_type = ExchangeType::HASH_SHUFFLE;
+    }
     RETURN_IF_ERROR(new_pip->set_sink(sink));
     RETURN_IF_ERROR(new_pip->sink_x()->init(data_distribution.distribution_type, num_buckets,
-                                            is_shuffled_hash_join, shuffle_idx_to_instance_idx));
+                                            should_disable_bucket_shuffle,
+                                            shuffle_idx_to_instance_idx));
 
     // 2. Create and initialize LocalExchangeSharedState.
     std::shared_ptr<LocalExchangeSharedState> shared_state =
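
Editor's note: the fallback decision above restated as a compact sketch (a hypothetical free function with assumed parameters; the -1 wildcard lookup in shuffle_idx_to_instance_idx is modeled here as a boolean):

    enum class ExchangeType { HASH_SHUFFLE, BUCKET_HASH_SHUFFLE, PASSTHROUGH };

    // Downgrade a planned BUCKET_HASH_SHUFFLE to HASH_SHUFFLE when the fragment
    // has no bucket-to-instance mapping, no wildcard shuffle idx, and the
    // downstream consumer is a shuffled hash join.
    ExchangeType choose_distribution(ExchangeType planned, bool bucket_map_empty,
                                     bool has_wildcard_shuffle_idx,
                                     bool followed_by_shuffled_join) {
        const bool should_disable_bucket_shuffle =
                bucket_map_empty && !has_wildcard_shuffle_idx && followed_by_shuffled_join;
        return should_disable_bucket_shuffle && planned == ExchangeType::BUCKET_HASH_SHUFFLE
                       ? ExchangeType::HASH_SHUFFLE
                       : planned;
    }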
@@ -712,7 +748,7 @@ Status PipelineFragmentContext::_add_local_exchange_impl(
     case ExchangeType::HASH_SHUFFLE:
         shared_state->exchanger = ShuffleExchanger::create_unique(
                 std::max(cur_pipe->num_tasks(), _num_instances),
-                is_shuffled_hash_join ? _total_instances : _num_instances,
+                should_disable_bucket_shuffle ? _total_instances : _num_instances,
                 _runtime_state->query_options().__isset.local_exchange_free_blocks_limit
                         ? _runtime_state->query_options().local_exchange_free_blocks_limit
                         : 0);
@@ -1124,7 +1160,8 @@ Status PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
                                                  const doris::TPipelineFragmentParams& request,
                                                  const DescriptorTbl& descs, OperatorXPtr& op,
                                                  PipelinePtr& cur_pipe, int parent_idx,
-                                                 int child_idx) {
+                                                 int child_idx,
+                                                 const bool followed_by_shuffled_join) {
     // We directly construct the operator from Thrift because the given array is in the order of preorder traversal.
     // Therefore, here we need to use a stack-like structure.
     _pipeline_parent_map.pop(cur_pipe, parent_idx, child_idx);
@@ -1215,6 +1252,7 @@ Status PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
             !tnode.agg_node.grouping_exprs.empty() && !group_by_limit_opt) {
             op.reset(new DistinctStreamingAggOperatorX(pool, next_operator_id(), tnode, descs,
                                                        _require_bucket_distribution));
+            op->set_followed_by_shuffled_join(followed_by_shuffled_join);
             _require_bucket_distribution =
                     _require_bucket_distribution || op->require_data_distribution();
             RETURN_IF_ERROR(cur_pipe->add_operator(op));
@@ -1246,6 +1284,7 @@ Status PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
                 sink.reset(new AggSinkOperatorX(pool, next_sink_operator_id(), tnode, descs,
                                                 _require_bucket_distribution));
             }
+            sink->set_followed_by_shuffled_join(followed_by_shuffled_join);
             _require_bucket_distribution =
                     _require_bucket_distribution || sink->require_data_distribution();
             sink->set_dests_id({op->operator_id()});
@@ -1392,6 +1431,7 @@ Status PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
             sink.reset(new SortSinkOperatorX(pool, next_sink_operator_id(), tnode, descs,
                                              _require_bucket_distribution));
         }
+        sink->set_followed_by_shuffled_join(followed_by_shuffled_join);
         _require_bucket_distribution =
                 _require_bucket_distribution || sink->require_data_distribution();
         sink->set_dests_id({op->operator_id()});
@@ -1431,6 +1471,7 @@ Status PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
         DataSinkOperatorXPtr sink;
         sink.reset(new AnalyticSinkOperatorX(pool, next_sink_operator_id(), tnode, descs,
                                              _require_bucket_distribution));
+        sink->set_followed_by_shuffled_join(followed_by_shuffled_join);
         _require_bucket_distribution =
                 _require_bucket_distribution || sink->require_data_distribution();
         sink->set_dests_id({op->operator_id()});
diff --git a/be/src/pipeline/pipeline_fragment_context.h b/be/src/pipeline/pipeline_fragment_context.h
index 7597c3ce9b5..06c88267441 100644
--- a/be/src/pipeline/pipeline_fragment_context.h
+++ b/be/src/pipeline/pipeline_fragment_context.h
@@ -145,12 +145,13 @@ private:
     Status _create_tree_helper(ObjectPool* pool, const std::vector<TPlanNode>& tnodes,
                                const doris::TPipelineFragmentParams& request,
                                const DescriptorTbl& descs, OperatorXPtr parent, int* node_idx,
-                               OperatorXPtr* root, PipelinePtr& cur_pipe, int child_idx);
+                               OperatorXPtr* root, PipelinePtr& cur_pipe, int child_idx,
+                               const bool followed_by_shuffled_join);
 
     Status _create_operator(ObjectPool* pool, const TPlanNode& tnode,
                             const doris::TPipelineFragmentParams& request,
                             const DescriptorTbl& descs, OperatorXPtr& op, PipelinePtr& cur_pipe,
-                            int parent_idx, int child_idx);
+                            int parent_idx, int child_idx, const bool followed_by_shuffled_join);
     template <bool is_intersect>
     Status _build_operators_for_set_operation_node(ObjectPool* pool, const TPlanNode& tnode,
                                                    const DescriptorTbl& descs, OperatorXPtr& op,
diff --git a/regression-test/suites/nereids_p0/join/test_join_local_shuffle.groovy b/regression-test/suites/nereids_p0/join/test_join_local_shuffle.groovy
index c66131b57dc..29fe192e2b5 100644
--- a/regression-test/suites/nereids_p0/join/test_join_local_shuffle.groovy
+++ b/regression-test/suites/nereids_p0/join/test_join_local_shuffle.groovy
@@ -16,6 +16,10 @@
 // under the License.
 
 suite("test_join_local_shuffle", "query,p0") {
+    sql "DROP TABLE IF EXISTS test_join_local_shuffle_1;"
+    sql "DROP TABLE IF EXISTS test_join_local_shuffle_2;"
+    sql "DROP TABLE IF EXISTS test_join_local_shuffle_3;"
+    sql "DROP TABLE IF EXISTS test_join_local_shuffle_4;"
     sql "SET enable_nereids_planner=true"
     sql "SET enable_fallback_to_original_planner=false"
     sql """
@@ -72,7 +76,7 @@ suite("test_join_local_shuffle", "query,p0") {
     sql "insert into test_join_local_shuffle_2 values(2, 0);"
     sql "insert into test_join_local_shuffle_3 values(2, 0);"
     sql "insert into test_join_local_shuffle_4 values(0, 1);"
-    qt_sql " select  /*+SET_VAR(disable_join_reorder=true,enable_local_shuffle=true) */ * from (select c1, max(c2) from (select b.c1 c1, b.c2 c2 from test_join_local_shuffle_3 a join [shuffle] test_join_local_shuffle_1 b on a.c2 = b.c1 join [broadcast] test_join_local_shuffle_4 c on b.c1 = c.c1) t1 group by c1) t, test_join_local_shuffle_2 where t.c1 = test_join_local_shuffle_2.c2; "
+    qt_sql " select  /*+SET_VAR(disable_join_reorder=true,enable_local_shuffle=true) */ * from (select c1, max(c2) from (select b.c1 c1, b.c2 c2 from test_join_local_shuffle_3 a join [shuffle] test_join_local_shuffle_1 b on a.c2 = b.c1 join [broadcast] test_join_local_shuffle_4 c on b.c1 = c.c1) t1 group by c1) t join [shuffle] test_join_local_shuffle_2 on t.c1 = test_join_local_shuffle_2.c2; "
 
     sql "DROP TABLE IF EXISTS test_join_local_shuffle_1;"
     sql "DROP TABLE IF EXISTS test_join_local_shuffle_2;"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to