This is an automated email from the ASF dual-hosted git repository.

BiteTheDDDDt pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 932f7898e65 [improvement](runtime-filter) Limit broadcast runtime filter producers (#64683)
932f7898e65 is described below

commit 932f7898e65134979ba9b6943003e6a8c4ed7891
Author: Pxl <[email protected]>
AuthorDate: Fri Jun 26 16:12:19 2026 +0800

    [improvement](runtime-filter) Limit broadcast runtime filter producers (#64683)
    
    Problem Summary: Broadcast join runtime filters are merged globally by
    accepting the first producer, yet every build-side BE still creates and
    sends an identical broadcast runtime filter. This wastes runtime filter
    build CPU and network bandwidth. This patch adds a Nereids-side producer
    selection step, controlled by `runtime_filter_broadcast_join_producer_num`
    (default 3; values less than or equal to 0 disable the limit), so that at
    most the configured number of BEs receive broadcast runtime filter
    producer descriptors with remote targets. When the selected
    runtime-filter merge coordinator also runs the build fragment, it is
    preferred as one of the broadcast runtime filter producers to reduce
    producer-to-merge network traffic. Local-only broadcast filters and
    non-broadcast filters keep the existing planning behavior. The session
    variable applies only to the Nereids distributed planner; the legacy
    Coordinator path keeps the existing behavior.
    
    ### Release note
    
    Add session variable `runtime_filter_broadcast_join_producer_num` (default
    3; values less than or equal to 0 disable the limit) to limit the number
    of producer BEs for Nereids broadcast join runtime filters.
---
 .../java/org/apache/doris/qe/SessionVariable.java  |  19 +++
 .../qe/runtime/RuntimeFiltersThriftBuilder.java    | 118 +++++++++++++--
 .../doris/qe/runtime/ThriftPlansBuilder.java       |  13 +-
 .../org/apache/doris/qe/SessionVariablesTest.java  |  16 +++
 .../runtime/RuntimeFiltersThriftBuilderTest.java   | 158 ++++++++++++++++++++-
 5 files changed, 312 insertions(+), 12 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index d017b921efd..2870c3443fe 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -254,6 +254,9 @@ public class SessionVariable implements Serializable, 
Writable {
     // if the right table is greater than this value in the hash join,  we 
will ignore IN filter
     public static final String RUNTIME_FILTER_MAX_IN_NUM = 
"runtime_filter_max_in_num";
 
+    public static final String RUNTIME_FILTER_BROADCAST_JOIN_PRODUCER_NUM =
+            "runtime_filter_broadcast_join_producer_num";
+
     public static final String ENABLE_SYNC_RUNTIME_FILTER_SIZE = 
"enable_sync_runtime_filter_size";
 
     public static final String ENABLE_PARALLEL_RESULT_SINK = 
"enable_parallel_result_sink";
@@ -1723,6 +1726,14 @@ public class SessionVariable implements Serializable, 
Writable {
     @VarAttrDef.VarAttr(name = ENABLE_SYNC_RUNTIME_FILTER_SIZE, needForward = 
true, fuzzy = true)
     private boolean enableSyncRuntimeFilterSize = true;
 
+    @VarAttrDef.VarAttr(name = RUNTIME_FILTER_BROADCAST_JOIN_PRODUCER_NUM, 
needForward = true,
+            description = {"控制 Nereids 分布式规划中每个 broadcast join runtime filter 
的生产 BE 数量。"
+                    + "设置为小于等于 0 时不限制。Legacy Coordinator 路径保持原行为。",
+                    "Controls the number of producer BEs for each broadcast 
join runtime filter in "
+                    + "the Nereids distributed planner. Values less than or 
equal to 0 disable the limit. "
+                    + "The legacy Coordinator path keeps the existing 
behavior."})
+    private int runtimeFilterBroadcastJoinProducerNum = 3;
+
     @VarAttrDef.VarAttr(name = "runtime_filter_max_build_row_count", 
needForward = true, fuzzy = false)
     public long runtimeFilterMaxBuildRowCount = 64L * 1024L * 1024L;
 
@@ -4733,6 +4744,14 @@ public class SessionVariable implements Serializable, 
Writable {
         this.runtimeFilterMaxInNum = runtimeFilterMaxInNum;
     }
 
+    public int getRuntimeFilterBroadcastJoinProducerNum() {
+        return runtimeFilterBroadcastJoinProducerNum;
+    }
+
+    public void setRuntimeFilterBroadcastJoinProducerNum(int 
runtimeFilterBroadcastJoinProducerNum) {
+        this.runtimeFilterBroadcastJoinProducerNum = 
runtimeFilterBroadcastJoinProducerNum;
+    }
+
     public void setEnableLocalShuffle(boolean enableLocalShuffle) {
         this.enableLocalShuffle = enableLocalShuffle;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/qe/runtime/RuntimeFiltersThriftBuilder.java
 
b/fe/fe-core/src/main/java/org/apache/doris/qe/runtime/RuntimeFiltersThriftBuilder.java
index cd7b0ad7bff..44c2f0e3e76 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/qe/runtime/RuntimeFiltersThriftBuilder.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/qe/runtime/RuntimeFiltersThriftBuilder.java
@@ -26,6 +26,9 @@ import org.apache.doris.planner.RuntimeFilter;
 import org.apache.doris.planner.RuntimeFilterId;
 import org.apache.doris.system.Backend;
 import org.apache.doris.thrift.TNetworkAddress;
+import org.apache.doris.thrift.TPlanFragment;
+import org.apache.doris.thrift.TPlanNode;
+import org.apache.doris.thrift.TRuntimeFilterDesc;
 import org.apache.doris.thrift.TRuntimeFilterParams;
 import org.apache.doris.thrift.TRuntimeFilterTargetParamsV2;
 
@@ -33,6 +36,8 @@ import com.google.common.base.Preconditions;
 import com.google.common.collect.Maps;
 
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -48,17 +53,57 @@ public class RuntimeFiltersThriftBuilder {
     private final Set<Integer> broadcastRuntimeFilterIds;
     private final Map<RuntimeFilterId, List<RuntimeFilterTarget>> ridToTargets;
     private final Map<RuntimeFilterId, Integer> ridToBuilderNum;
+    private final boolean limitBroadcastRuntimeFilterProducers;
+    private final Map<Long, List<Integer>> workerIdToBroadcastRuntimeFilterIds;
+    private final Map<Integer, Integer> 
broadcastRuntimeFilterIdToBuilderNodeId;
 
     private RuntimeFiltersThriftBuilder(
             TNetworkAddress mergeAddress, List<RuntimeFilter> runtimeFilters,
             Set<Integer> broadcastRuntimeFilterIds,
             Map<RuntimeFilterId, List<RuntimeFilterTarget>> ridToTargets,
-            Map<RuntimeFilterId, Integer> ridToBuilderNum) {
+            Map<RuntimeFilterId, Integer> ridToBuilderNum,
+            boolean limitBroadcastRuntimeFilterProducers,
+            Map<Long, List<Integer>> workerIdToBroadcastRuntimeFilterIds,
+            Map<Integer, Integer> broadcastRuntimeFilterIdToBuilderNodeId) {
         this.mergeAddress = mergeAddress;
         this.runtimeFilters = runtimeFilters;
         this.broadcastRuntimeFilterIds = broadcastRuntimeFilterIds;
         this.ridToTargets = ridToTargets;
         this.ridToBuilderNum = ridToBuilderNum;
+        this.limitBroadcastRuntimeFilterProducers = 
limitBroadcastRuntimeFilterProducers;
+        this.workerIdToBroadcastRuntimeFilterIds = 
workerIdToBroadcastRuntimeFilterIds;
+        this.broadcastRuntimeFilterIdToBuilderNodeId = 
broadcastRuntimeFilterIdToBuilderNodeId;
+    }
+
+    public void pruneBroadcastRuntimeFilterProducers(
+            TPlanFragment planFragment, DistributedPlanWorker worker) {
+        if (!limitBroadcastRuntimeFilterProducers) {
+            return;
+        }
+        Set<Integer> producerFilterIds = new HashSet<>(
+                workerIdToBroadcastRuntimeFilterIds.getOrDefault(worker.id(), 
Collections.emptyList()));
+        if (planFragment.isSetPlan()) {
+            for (TPlanNode node : planFragment.getPlan().getNodes()) {
+                if (node.isSetRuntimeFilters()) {
+                    node.setRuntimeFilters(pruneRuntimeFilterDescs(
+                            node.getNodeId(), node.getRuntimeFilters(), 
producerFilterIds));
+                }
+            }
+        }
+    }
+
+    private List<TRuntimeFilterDesc> pruneRuntimeFilterDescs(
+            int nodeId, List<TRuntimeFilterDesc> runtimeFilterDescs, 
Set<Integer> producerFilterIds) {
+        List<TRuntimeFilterDesc> selectedRuntimeFilterDescs = new 
ArrayList<>(runtimeFilterDescs.size());
+        for (TRuntimeFilterDesc desc : runtimeFilterDescs) {
+            Integer builderNodeId = 
broadcastRuntimeFilterIdToBuilderNodeId.get(desc.filter_id);
+            if (!desc.is_broadcast_join || !desc.has_remote_targets
+                    || builderNodeId == null || builderNodeId != nodeId
+                    || producerFilterIds.contains(desc.filter_id)) {
+                selectedRuntimeFilterDescs.add(desc);
+            }
+        }
+        return selectedRuntimeFilterDescs;
     }
 
     public void populateRuntimeFilterParams(TRuntimeFilterParams 
runtimeFilterParams) {
@@ -100,9 +145,20 @@ public class RuntimeFiltersThriftBuilder {
 
     public static RuntimeFiltersThriftBuilder compute(
             List<RuntimeFilter> runtimeFilters, List<PipelineDistributedPlan> 
distributedPlans) {
+        return compute(runtimeFilters, distributedPlans, 0);
+    }
+
+    public static RuntimeFiltersThriftBuilder compute(
+            List<RuntimeFilter> runtimeFilters, List<PipelineDistributedPlan> 
distributedPlans,
+            int broadcastRuntimeFilterProducerNum) {
         BackendWorker worker = selectMergeWorker(distributedPlans);
         TNetworkAddress mergeAddress = new TNetworkAddress(worker.host(), 
worker.brpcPort());
 
+        Map<Integer, RuntimeFilter> idToRuntimeFilter = runtimeFilters
+                .stream()
+                .collect(Collectors.toMap(r -> r.getFilterId().asInt(), r -> 
r, (left, right) -> left,
+                        LinkedHashMap::new));
+
         Set<Integer> broadcastRuntimeFilterIds = runtimeFilters
                 .stream()
                 .filter(RuntimeFilter::isBroadcast)
@@ -111,6 +167,10 @@ public class RuntimeFiltersThriftBuilder {
 
         Map<RuntimeFilterId, List<RuntimeFilterTarget>> ridToTargetParam = 
Maps.newLinkedHashMap();
         Map<RuntimeFilterId, Integer> ridToBuilderNum = 
Maps.newLinkedHashMap();
+        Map<Integer, List<BackendWorker>> builderNodeToProducerWorkers = 
Maps.newLinkedHashMap();
+        Map<Long, List<Integer>> workerIdToBroadcastRuntimeFilterIds = 
Maps.newLinkedHashMap();
+        Map<Integer, Integer> broadcastRuntimeFilterIdToBuilderNodeId = 
Maps.newLinkedHashMap();
+        boolean limitBroadcastRuntimeFilterProducers = 
broadcastRuntimeFilterProducerNum > 0;
         for (PipelineDistributedPlan plan : distributedPlans) {
             PlanFragment fragment = plan.getFragmentJob().getFragment();
             // Transform <fragment, runtimeFilterId> to <runtimeFilterId, 
fragment>
@@ -126,18 +186,60 @@ public class RuntimeFiltersThriftBuilder {
                 }
             }
 
+            List<BackendWorker> builderWorkers = 
collectDistinctBackendWorkers(plan.getInstanceJobs());
+            int distinctWorkerNum = builderWorkers.size();
             for (RuntimeFilterId rid : fragment.getBuilderRuntimeFilterIds()) {
-                int distinctWorkerNum = (int) plan.getInstanceJobs()
-                        .stream()
-                        .map(AssignedJob::getAssignedWorker)
-                        .map(DistributedPlanWorker::id)
-                        .distinct()
-                        .count();
                 ridToBuilderNum.merge(rid, distinctWorkerNum, Integer::sum);
+                RuntimeFilter rf = idToRuntimeFilter.get(rid.asInt());
+                if (limitBroadcastRuntimeFilterProducers
+                        && rf != null && rf.isBroadcast() && 
rf.hasRemoteTargets()) {
+                    int builderNodeId = rf.getBuilderNode().getId().asInt();
+                    broadcastRuntimeFilterIdToBuilderNodeId.put(rid.asInt(), 
builderNodeId);
+                    List<BackendWorker> producerWorkers = 
builderNodeToProducerWorkers.computeIfAbsent(
+                            builderNodeId,
+                            id -> selectBroadcastRuntimeFilterProducerWorkers(
+                                    builderWorkers, 
broadcastRuntimeFilterProducerNum, worker));
+                    for (BackendWorker producerWorker : producerWorkers) {
+                        workerIdToBroadcastRuntimeFilterIds.computeIfAbsent(
+                                producerWorker.id(), id -> new 
ArrayList<>()).add(rid.asInt());
+                    }
+                }
             }
         }
         return new RuntimeFiltersThriftBuilder(
-                mergeAddress, runtimeFilters, broadcastRuntimeFilterIds, 
ridToTargetParam, ridToBuilderNum);
+                mergeAddress, runtimeFilters, broadcastRuntimeFilterIds, 
ridToTargetParam, ridToBuilderNum,
+                limitBroadcastRuntimeFilterProducers, 
workerIdToBroadcastRuntimeFilterIds,
+                broadcastRuntimeFilterIdToBuilderNodeId);
+    }
+
+    static List<BackendWorker> collectDistinctBackendWorkers(List<AssignedJob> 
instanceJobs) {
+        Map<Long, BackendWorker> workerMap = Maps.newLinkedHashMap();
+        for (AssignedJob instanceJob : instanceJobs) {
+            BackendWorker worker = (BackendWorker) 
instanceJob.getAssignedWorker();
+            workerMap.putIfAbsent(worker.id(), worker);
+        }
+        return new ArrayList<>(workerMap.values());
+    }
+
+    static List<BackendWorker> selectBroadcastRuntimeFilterProducerWorkers(
+            List<BackendWorker> workers, int producerNum, BackendWorker 
preferredWorker) {
+        Preconditions.checkArgument(producerNum > 0,
+                "broadcast runtime filter producer num must be positive");
+        if (workers.size() <= producerNum) {
+            return workers;
+        }
+        List<BackendWorker> selectedWorkers = new ArrayList<>(producerNum);
+        for (BackendWorker worker : workers) {
+            if (worker.equals(preferredWorker)) {
+                selectedWorkers.add(worker);
+                break;
+            }
+        }
+        List<BackendWorker> remainingWorkers = new ArrayList<>(workers);
+        remainingWorkers.removeAll(selectedWorkers);
+        Collections.shuffle(remainingWorkers, ThreadLocalRandom.current());
+        selectedWorkers.addAll(remainingWorkers.subList(0, producerNum - 
selectedWorkers.size()));
+        return selectedWorkers;
     }
 
     static BackendWorker selectMergeWorker(List<PipelineDistributedPlan> 
distributedPlans) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/qe/runtime/ThriftPlansBuilder.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/runtime/ThriftPlansBuilder.java
index 7c95ef7dfda..672b3d245f4 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/qe/runtime/ThriftPlansBuilder.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/qe/runtime/ThriftPlansBuilder.java
@@ -116,8 +116,12 @@ public class ThriftPlansBuilder {
         // we should set runtime predicate first, then we can use heap sort 
and to thrift
         setRuntimePredicateIfNeed(coordinatorContext.scanNodes);
 
+        int broadcastRuntimeFilterProducerNum = 
coordinatorContext.connectContext == null
+                ? 0
+                : coordinatorContext.connectContext.getSessionVariable()
+                        .getRuntimeFilterBroadcastJoinProducerNum();
         RuntimeFiltersThriftBuilder runtimeFiltersThriftBuilder = 
RuntimeFiltersThriftBuilder.compute(
-                coordinatorContext.runtimeFilters, distributedPlans);
+                coordinatorContext.runtimeFilters, distributedPlans, 
broadcastRuntimeFilterProducerNum);
         Supplier<List<TTopnFilterDesc>> topNFilterThriftSupplier
                 = topNFilterToThrift(coordinatorContext.topnFilters);
 
@@ -137,7 +141,8 @@ public class ThriftPlansBuilder {
                 TPipelineFragmentParams currentFragmentParam = 
fragmentToThriftIfAbsent(
                         currentFragmentPlan, instanceJob, 
workerToCurrentFragment,
                         instancesPerWorker, exchangeSenderNum, 
sharedFileScanRangeParams,
-                        workerProcessInstanceNum, fragmentToNotifyClose, 
coordinatorContext);
+                        workerProcessInstanceNum, fragmentToNotifyClose, 
coordinatorContext,
+                        runtimeFiltersThriftBuilder);
 
                 TPipelineInstanceParams instanceParam = instanceToThrift(
                         currentFragmentParam, instanceJob, 
currentInstanceIndex++);
@@ -369,7 +374,8 @@ public class ThriftPlansBuilder {
             Map<Integer, TFileScanRangeParams> fileScanRangeParamsMap,
             Multiset<DistributedPlanWorker> workerProcessInstanceNum,
             Set<Integer> fragmentToNotifyClose,
-            CoordinatorContext coordinatorContext) {
+            CoordinatorContext coordinatorContext,
+            RuntimeFiltersThriftBuilder runtimeFiltersThriftBuilder) {
         DistributedPlanWorker worker = assignedJob.getAssignedWorker();
         return workerToFragmentParams.computeIfAbsent(worker, w -> {
             PlanFragment fragment = 
fragmentPlan.getFragmentJob().getFragment();
@@ -424,6 +430,7 @@ public class ThriftPlansBuilder {
             
params.setSendQueryStatisticsWithEveryBatch(fragment.isTransferQueryStatisticsWithEveryBatch());
 
             TPlanFragment planThrift = fragment.toThrift();
+            
runtimeFiltersThriftBuilder.pruneBroadcastRuntimeFilterProducers(planThrift, 
worker);
             planThrift.query_cache_param = fragment.queryCacheParam;
             params.setFragment(planThrift);
             params.setLocalParams(Lists.newArrayList());
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/qe/SessionVariablesTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/qe/SessionVariablesTest.java
index be7b42cfe77..a8446efb451 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/qe/SessionVariablesTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/qe/SessionVariablesTest.java
@@ -155,6 +155,22 @@ public class SessionVariablesTest extends 
TestWithFeService {
                 () -> sessionVar.checkInsertVisibleTimeoutReturnMode(""));
     }
 
+    @Test
+    public void testRuntimeFilterBroadcastJoinProducerNumDescription() throws 
Exception {
+        SessionVariable sessionVar = new SessionVariable();
+        Assertions.assertEquals(3, 
sessionVar.getRuntimeFilterBroadcastJoinProducerNum());
+
+        Field field = 
SessionVariable.class.getDeclaredField("runtimeFilterBroadcastJoinProducerNum");
+        VarAttrDef.VarAttr varAttr = 
field.getAnnotation(VarAttrDef.VarAttr.class);
+        Assertions.assertArrayEquals(new String[] {
+                "控制 Nereids 分布式规划中每个 broadcast join runtime filter 的生产 BE 数量。"
+                        + "设置为小于等于 0 时不限制。Legacy Coordinator 路径保持原行为。",
+                "Controls the number of producer BEs for each broadcast join 
runtime filter in "
+                        + "the Nereids distributed planner. Values less than 
or equal to 0 disable the limit. "
+                        + "The legacy Coordinator path keeps the existing 
behavior."
+        }, varAttr.description());
+    }
+
     @Test
     public void testForceEagerAggHintParseWhenSetSessionVariable() throws 
Exception {
         SessionVariable sessionVar = new SessionVariable();
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/qe/runtime/RuntimeFiltersThriftBuilderTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/qe/runtime/RuntimeFiltersThriftBuilderTest.java
index b3d4b4f80ce..ab8f50e0839 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/qe/runtime/RuntimeFiltersThriftBuilderTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/qe/runtime/RuntimeFiltersThriftBuilderTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.qe.runtime;
 
+import org.apache.doris.common.IdGenerator;
 import org.apache.doris.nereids.StatementContext;
 import org.apache.doris.nereids.trees.AbstractTreeNode;
 import org.apache.doris.nereids.trees.plans.distribute.DistributeContext;
@@ -27,8 +28,17 @@ import 
org.apache.doris.nereids.trees.plans.distribute.worker.job.DefaultScanSou
 import 
org.apache.doris.nereids.trees.plans.distribute.worker.job.UnassignedJob;
 import org.apache.doris.planner.ExchangeNode;
 import org.apache.doris.planner.PlanFragment;
+import org.apache.doris.planner.PlanFragmentId;
+import org.apache.doris.planner.PlanNode;
+import org.apache.doris.planner.PlanNodeId;
+import org.apache.doris.planner.RuntimeFilter;
+import org.apache.doris.planner.RuntimeFilterId;
 import org.apache.doris.planner.ScanNode;
 import org.apache.doris.system.Backend;
+import org.apache.doris.thrift.TPlan;
+import org.apache.doris.thrift.TPlanFragment;
+import org.apache.doris.thrift.TPlanNode;
+import org.apache.doris.thrift.TRuntimeFilterDesc;
 import org.apache.doris.thrift.TUniqueId;
 
 import com.google.common.collect.ArrayListMultimap;
@@ -36,10 +46,13 @@ import com.google.common.collect.ImmutableSetMultimap;
 import com.google.common.collect.ListMultimap;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
 
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 public class RuntimeFiltersThriftBuilderTest {
@@ -59,6 +72,96 @@ public class RuntimeFiltersThriftBuilderTest {
         
Assertions.assertTrue(candidates.contains(RuntimeFiltersThriftBuilder.selectMergeWorker(distributedPlans)));
     }
 
+    @Test
+    public void 
testSelectBroadcastRuntimeFilterProducerWorkersPreferMergeWorker() {
+        BackendWorker worker1 = newBackendWorker(1);
+        BackendWorker worker2 = newBackendWorker(2);
+        BackendWorker worker3 = newBackendWorker(3);
+
+        List<BackendWorker> selectedWorkers = 
RuntimeFiltersThriftBuilder.selectBroadcastRuntimeFilterProducerWorkers(
+                Arrays.asList(worker1, worker2, worker3), 1, worker2);
+
+        Assertions.assertEquals(Collections.singletonList(worker2), 
selectedWorkers);
+    }
+
+    @Test
+    public void 
testSelectBroadcastRuntimeFilterProducerWorkersIgnoreNonCandidatePreferredWorker()
 {
+        BackendWorker worker1 = newBackendWorker(1);
+        BackendWorker worker2 = newBackendWorker(2);
+        BackendWorker worker3 = newBackendWorker(3);
+
+        List<BackendWorker> selectedWorkers = 
RuntimeFiltersThriftBuilder.selectBroadcastRuntimeFilterProducerWorkers(
+                Arrays.asList(worker1, worker2), 1, worker3);
+
+        Assertions.assertEquals(1, selectedWorkers.size());
+        Assertions.assertTrue(Arrays.asList(worker1, 
worker2).contains(selectedWorkers.get(0)));
+        Assertions.assertFalse(selectedWorkers.contains(worker3));
+    }
+
+    @Test
+    public void testPruneBroadcastRuntimeFilterProducers() {
+        BackendWorker worker1 = newBackendWorker(1);
+        BackendWorker worker2 = newBackendWorker(2);
+        BackendWorker worker3 = newBackendWorker(3);
+
+        IdGenerator<RuntimeFilterId> idGenerator = 
RuntimeFilterId.createGenerator();
+        RuntimeFilterId broadcastRid1 = idGenerator.getNextId();
+        RuntimeFilterId broadcastRid2 = idGenerator.getNextId();
+        RuntimeFilterId localBroadcastRid = idGenerator.getNextId();
+        RuntimeFilterId shuffleRid = idGenerator.getNextId();
+        RuntimeFilter broadcastRf1 = newRuntimeFilter(broadcastRid1, 10, true, 
true);
+        RuntimeFilter broadcastRf2 = newRuntimeFilter(broadcastRid2, 10, true, 
true);
+        RuntimeFilter localBroadcastRf = newRuntimeFilter(localBroadcastRid, 
10, true, false);
+        RuntimeFilter shuffleRf = newRuntimeFilter(shuffleRid, 11, false, 
true);
+
+        PlanFragment fragment = newFragment(0, broadcastRid1, broadcastRid2, 
localBroadcastRid, shuffleRid);
+        PipelineDistributedPlan distributedPlan = newDistributedPlan(fragment, 
worker1, worker2, worker3);
+        RuntimeFiltersThriftBuilder builder = 
RuntimeFiltersThriftBuilder.compute(
+                Arrays.asList(broadcastRf1, broadcastRf2, localBroadcastRf, 
shuffleRf),
+                Collections.singletonList(distributedPlan), 1);
+
+        int selectedWorkerNum = 0;
+        for (BackendWorker worker : Arrays.asList(worker1, worker2, worker3)) {
+            TPlanFragment planFragment = newPlanFragment(
+                    newPlanNode(10,
+                            newRuntimeFilterDesc(broadcastRid1, true, true),
+                            newRuntimeFilterDesc(broadcastRid2, true, true),
+                            newRuntimeFilterDesc(localBroadcastRid, true, 
false)),
+                    newPlanNode(11,
+                            newRuntimeFilterDesc(shuffleRid, false, true)),
+                    newPlanNode(20,
+                            newRuntimeFilterDesc(broadcastRid1, true, true),
+                            newRuntimeFilterDesc(broadcastRid2, true, true)));
+            builder.pruneBroadcastRuntimeFilterProducers(planFragment, worker);
+
+            List<Integer> builderFilterIds = 
planFragment.getPlan().getNodes().get(0).getRuntimeFilters()
+                    .stream()
+                    .map(desc -> desc.filter_id)
+                    .collect(Collectors.toList());
+            List<Integer> shuffleFilterIds = 
planFragment.getPlan().getNodes().get(1).getRuntimeFilters()
+                    .stream()
+                    .map(desc -> desc.filter_id)
+                    .collect(Collectors.toList());
+            List<Integer> targetFilterIds = 
planFragment.getPlan().getNodes().get(2).getRuntimeFilters()
+                    .stream()
+                    .map(desc -> desc.filter_id)
+                    .collect(Collectors.toList());
+            
Assertions.assertEquals(Collections.singletonList(shuffleRid.asInt()), 
shuffleFilterIds);
+            Assertions.assertEquals(new HashSet<>(Arrays.asList(
+                    broadcastRid1.asInt(), broadcastRid2.asInt())), new 
HashSet<>(targetFilterIds));
+            if (builderFilterIds.contains(broadcastRid1.asInt())) {
+                selectedWorkerNum++;
+                Assertions.assertEquals(new HashSet<>(Arrays.asList(
+                        broadcastRid1.asInt(), broadcastRid2.asInt(), 
localBroadcastRid.asInt())),
+                        new HashSet<>(builderFilterIds));
+            } else {
+                
Assertions.assertFalse(builderFilterIds.contains(broadcastRid2.asInt()));
+                
Assertions.assertEquals(Collections.singletonList(localBroadcastRid.asInt()), 
builderFilterIds);
+            }
+        }
+        Assertions.assertEquals(1, selectedWorkerNum);
+    }
+
     private BackendWorker newBackendWorker(long id) {
         Backend backend = new Backend(id, "host" + id, (int) (9000 + id));
         backend.setBePort((int) (8000 + id));
@@ -67,7 +170,12 @@ public class RuntimeFiltersThriftBuilderTest {
     }
 
     private PipelineDistributedPlan newDistributedPlan(BackendWorker... 
workers) {
+        return newDistributedPlan(null, workers);
+    }
+
+    private PipelineDistributedPlan newDistributedPlan(PlanFragment fragment, 
BackendWorker... workers) {
         TestUnassignedJob unassignedJob = new TestUnassignedJob();
+        unassignedJob.fragment = fragment;
         List<AssignedJob> assignedJobs = Arrays.stream(workers)
                 .map(worker -> unassignedJob.assignWorkerAndDataSources(
                         0, new TUniqueId(), worker, DefaultScanSource.empty()))
@@ -75,7 +183,55 @@ public class RuntimeFiltersThriftBuilderTest {
         return new PipelineDistributedPlan(unassignedJob, assignedJobs, 
ImmutableSetMultimap.of());
     }
 
+    private RuntimeFilter newRuntimeFilter(RuntimeFilterId rid, int 
builderNodeId,
+            boolean isBroadcast, boolean hasRemoteTargets) {
+        PlanNode builderNode = Mockito.mock(PlanNode.class);
+        Mockito.when(builderNode.getId()).thenReturn(new 
PlanNodeId(builderNodeId));
+
+        RuntimeFilter runtimeFilter = Mockito.mock(RuntimeFilter.class);
+        Mockito.when(runtimeFilter.getFilterId()).thenReturn(rid);
+        Mockito.when(runtimeFilter.getBuilderNode()).thenReturn(builderNode);
+        Mockito.when(runtimeFilter.isBroadcast()).thenReturn(isBroadcast);
+        
Mockito.when(runtimeFilter.hasRemoteTargets()).thenReturn(hasRemoteTargets);
+        return runtimeFilter;
+    }
+
+    private TPlanFragment newPlanFragment(TPlanNode... nodes) {
+        TPlan plan = new TPlan();
+        plan.setNodes(Arrays.asList(nodes));
+        TPlanFragment fragment = new TPlanFragment();
+        fragment.setPlan(plan);
+        return fragment;
+    }
+
+    private TPlanNode newPlanNode(int nodeId, TRuntimeFilterDesc... 
runtimeFilterDescs) {
+        TPlanNode node = new TPlanNode();
+        node.setNodeId(nodeId);
+        node.setRuntimeFilters(Arrays.asList(runtimeFilterDescs));
+        return node;
+    }
+
+    private TRuntimeFilterDesc newRuntimeFilterDesc(
+            RuntimeFilterId rid, boolean isBroadcast, boolean 
hasRemoteTargets) {
+        TRuntimeFilterDesc desc = new TRuntimeFilterDesc();
+        desc.setFilterId(rid.asInt());
+        desc.setIsBroadcastJoin(isBroadcast);
+        desc.setHasRemoteTargets(hasRemoteTargets);
+        return desc;
+    }
+
+    private PlanFragment newFragment(int fragmentId, RuntimeFilterId... 
builderRuntimeFilterIds) {
+        PlanFragment fragment = Mockito.mock(PlanFragment.class);
+        Mockito.when(fragment.getFragmentId()).thenReturn(new 
PlanFragmentId(fragmentId));
+        Set<RuntimeFilterId> builderIds = new 
HashSet<>(Arrays.asList(builderRuntimeFilterIds));
+        
Mockito.when(fragment.getBuilderRuntimeFilterIds()).thenReturn(builderIds);
+        
Mockito.when(fragment.getTargetRuntimeFilterIds()).thenReturn(Collections.emptySet());
+        return fragment;
+    }
+
     private static final class TestUnassignedJob extends 
AbstractTreeNode<UnassignedJob> implements UnassignedJob {
+        private PlanFragment fragment;
+
         private TestUnassignedJob() {
             super(Collections.emptyList());
         }
@@ -87,7 +243,7 @@ public class RuntimeFiltersThriftBuilderTest {
 
         @Override
         public PlanFragment getFragment() {
-            return null;
+            return fragment;
         }
 
         @Override


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to