This is an automated email from the ASF dual-hosted git repository.

gortiz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new fd3b079615 Refactor DispatchableSubPlan to store fragments by stage id 
(#15135)
fd3b079615 is described below

commit fd3b0796156d51ce4239cca02fd925bf1728c885
Author: Gonzalo Ortiz Jaureguizar <[email protected]>
AuthorDate: Fri Feb 28 15:06:54 2025 +0100

    Refactor DispatchableSubPlan to store fragments by stage id (#15135)
    
    Previously, they were stored in a list whose index was the stage id. This
made sense when stage ids formed a dense set, but spools broke that assumption.
---
 .../MultiStageBrokerRequestHandler.java            |   6 +-
 .../integration/tests/SpoolIntegrationTest.java    | 161 ++++++++++++++++++++-
 .../org/apache/pinot/query/QueryEnvironment.java   |   2 +-
 .../apache/pinot/query/planner/PlanFragment.java   |   8 +
 .../explain/PhysicalExplainPlanVisitor.java        |  12 +-
 .../planner/physical/DispatchablePlanContext.java  |  38 +++--
 .../planner/physical/DispatchableSubPlan.java      |  52 ++++++-
 .../planner/physical/PinotDispatchPlanner.java     |   2 +-
 .../apache/pinot/query/QueryCompilationTest.java   |  41 +++---
 .../query/planner/serde/PlanNodeSerDeTest.java     |   2 +-
 .../query/runtime/MultiStageStatsTreeBuilder.java  |  15 +-
 .../query/service/dispatch/QueryDispatcher.java    |  47 +++---
 .../query/runtime/queries/QueryRunnerTestBase.java |  14 +-
 .../runtime/queries/ResourceBasedQueriesTest.java  |   2 +-
 .../query/service/server/QueryServerTest.java      |  23 ++-
 15 files changed, 321 insertions(+), 104 deletions(-)

diff --git 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index 771e978341..80dba42f8e 100644
--- 
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ 
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -263,7 +263,7 @@ public class MultiStageBrokerRequestHandler extends 
BaseBrokerRequestHandler {
     DispatchableSubPlan dispatchableSubPlan = queryPlanResult.getQueryPlan();
 
     Set<QueryServerInstance> servers = new HashSet<>();
-    for (DispatchablePlanFragment planFragment: 
dispatchableSubPlan.getQueryStageList()) {
+    for (DispatchablePlanFragment planFragment: 
dispatchableSubPlan.getQueryStageMap().values()) {
       servers.addAll(planFragment.getServerInstances());
     }
 
@@ -443,9 +443,9 @@ public class MultiStageBrokerRequestHandler extends 
BaseBrokerRequestHandler {
   private void fillOldBrokerResponseStats(BrokerResponseNativeV2 
brokerResponse,
       List<MultiStageQueryStats.StageStats.Closed> queryStats, 
DispatchableSubPlan dispatchableSubPlan) {
     try {
-      List<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStageList();
+      Map<Integer, DispatchablePlanFragment> queryStageMap = 
dispatchableSubPlan.getQueryStageMap();
 
-      MultiStageStatsTreeBuilder treeBuilder = new 
MultiStageStatsTreeBuilder(stagePlans, queryStats);
+      MultiStageStatsTreeBuilder treeBuilder = new 
MultiStageStatsTreeBuilder(queryStageMap, queryStats);
       brokerResponse.setStageStats(treeBuilder.jsonStatsByStage(0));
       for (MultiStageQueryStats.StageStats.Closed stageStats : queryStats) {
         if (stageStats != null) { // for example pipeline breaker may not have 
stats
diff --git 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java
 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java
index 91f86f4f65..1c37808401 100644
--- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java
+++ 
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java
@@ -22,13 +22,17 @@ import com.fasterxml.jackson.databind.JsonNode;
 import com.jayway.jsonpath.DocumentContext;
 import com.jayway.jsonpath.JsonPath;
 import java.io.File;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.data.Schema;
 import org.apache.pinot.spi.env.PinotConfiguration;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.util.TestUtils;
+import org.intellij.lang.annotations.Language;
 import org.testcontainers.shaded.org.apache.commons.io.FileUtils;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
@@ -117,13 +121,160 @@ public class SpoolIntegrationTest extends 
BaseClusterIntegrationTest
     JsonNode stats = jsonNode.get("stageStats");
     assertNoError(jsonNode);
     DocumentContext parsed = JsonPath.parse(stats.toString());
-    List<Map<String, Object>> stage4On3 = parsed.read("$..[?(@.stage == 
3)]..[?(@.stage == 4)]");
-    Assert.assertEquals(stage4On3.size(), 1, "Stage 4 should be descended from 
stage 3 exactly once");
 
-    List<Map<String, Object>> stage4On7 = parsed.read("$..[?(@.stage == 
7)]..[?(@.stage == 4)]");
-    Assert.assertEquals(stage4On3.size(), 1, "Stage 4 should be descended from 
stage 7 exactly once");
+    checkSpoolTimes(parsed, 4, 3, 1);
+    checkSpoolTimes(parsed, 4, 7, 1);
+    checkSpoolSame(parsed, 4, 3, 7);
+  }
+
+  /**
+   * Test a complex query with nested spools.
+   *
+   * Don't try to understand the query; just check that the spools are correct.
+   * This query is an actual simplification of a query used in production.
+   * It was the way we detected problems fixed in <a 
href="https://github.com/apache/pinot/pull/15135">#15135</a>.
+   */
+  @Test
+  public void testNestedSpools()
+      throws Exception {
+    JsonNode jsonNode = postQuery("SET useSpools = true;\n"
+        + "\n"
+        + "WITH\n"
+        + "    q1 AS (\n"
+        + "        SELECT ArrTimeBlk as userUUID,\n"
+        + "               Dest as deviceOS,\n"
+        + "               SUM(ArrTime) AS totalTrips\n"
+        + "        FROM mytable\n"
+        + "        GROUP BY ArrTimeBlk, Dest\n"
+        + "    ),\n"
+        + "     q2 AS (\n"
+        + "         SELECT userUUID,\n"
+        + "                deviceOS,\n"
+        + "                SUM(totalTrips) AS totalTrips,\n"
+        + "                COUNT(DISTINCT userUUID) AS reach\n"
+        + "         FROM q1\n"
+        + "         GROUP BY userUUID,\n"
+        + "                  deviceOS\n"
+        + "     ),\n"
+        + "     q3 AS (\n"
+        + "         SELECT userUUID,\n"
+        + "                (totalTrips / reach) AS frequency\n"
+        + "         FROM q2\n"
+        + "     ),\n"
+        + "     q4 AS (\n"
+        + "         SELECT rd.userUUID,\n"
+        + "                rd.deviceOS,\n"
+        + "                rd.totalTrips as totalTrips,\n"
+        + "                rd.reach AS reach\n"
+        + "         FROM q2 rd\n"
+        + "     ),\n"
+        + "     q5 AS (\n"
+        + "         SELECT userUUID,\n"
+        + "                SUM(totalTrips) AS totalTrips\n"
+        + "         FROM q4\n"
+        + "         GROUP BY userUUID\n"
+        + "     ),\n"
+        + "     q6 AS (\n"
+        + "         SELECT s.userUUID,\n"
+        + "                s.totalTrips,\n"
+        + "                (s.totalTrips / o.frequency) AS reach,\n"
+        + "                'some fake device' AS deviceOS\n"
+        + "         FROM q5 s\n"
+        + "                  JOIN q3 o ON s.userUUID = o.userUUID\n"
+        + "     ),\n"
+        + "     q7 AS (\n"
+        + "         SELECT rd.userUUID,\n"
+        + "                rd.totalTrips,\n"
+        + "                rd.reach,\n"
+        + "                rd.deviceOS\n"
+        + "         FROM q4 rd\n"
+        + "         UNION ALL\n"
+        + "         SELECT f.userUUID,\n"
+        + "                f.totalTrips,\n"
+        + "                f.reach,\n"
+        + "                f.deviceOS\n"
+        + "         FROM q6 f\n"
+        + "     ),\n"
+        + "     q8 AS (\n"
+        + "         SELECT sd.*\n"
+        + "         FROM q7 sd\n"
+        + "                  JOIN (\n"
+        + "             SELECT deviceOS,\n"
+        + "                    PERCENTILETDigest(totalTrips, 20) AS p20\n"
+        + "             FROM q7\n"
+        + "             GROUP BY deviceOS\n"
+        + "         ) q ON sd.deviceOS = q.deviceOS\n"
+        + "     )\n"
+        + "SELECT *\n"
+        + "FROM q8");
+    JsonNode stats = jsonNode.get("stageStats");
+    assertNoError(jsonNode);
+    DocumentContext parsed = JsonPath.parse(stats.toString());
+
+    /*
+     * Stages are like follows:
+     * 1 -> 2 (union)  ->  3 (aggr) ->  4 (leaf, spooled)
+     *                 ->  5 (join) ->  6 (aggr, spooled) ->  7 (aggr, 
spooled) -> 4 (leaf, spooled)
+     *                              ->  9 (aggr)          ->  4 (leaf, spooled)
+     *   -> 11 (union) -> 12 (aggr) ->  4 (leaf, spooled)
+     *                 -> 14 (join) ->  6 (aggr, spooled) ->  7 (aggr, 
spooled) -> 4 (leaf, spooled)
+     *                              -> 18 (aggr)          ->  4 (leaf, spooled)
+     */
+    checkSpoolTimes(parsed, 6, 5, 1);
+    checkSpoolTimes(parsed, 6, 14, 1);
+    checkSpoolSame(parsed, 6, 5, 14);
+
+    checkSpoolTimes(parsed, 7, 6, 2);
+
+    checkSpoolTimes(parsed, 4, 3, 1);
+    checkSpoolTimes(parsed, 4, 7, 2); // because there are 2 copies of 7 as 
well
+    checkSpoolTimes(parsed, 4, 9, 1);
+    checkSpoolTimes(parsed, 4, 12, 1);
+    checkSpoolTimes(parsed, 4, 18, 1);
+    checkSpoolSame(parsed, 4, 3, 7, 9, 12, 18);
+  }
+
+  /**
+   * Returns the nodes that have the given descendant id and also have the 
given parent id as one of their ancestors.
+   * @param parent the parent id
+   * @param descendant the descendant id
+   */
+  private List<Map<String, Object>> findDescendantById(DocumentContext stats, 
int parent, int descendant) {
+    @Language("jsonpath")
+    String jsonPath = "$..[?(@.stage == " + parent + ")]..[?(@.stage == " + 
descendant + ")]";
+    return stats.read(jsonPath);
+  }
+
+  private void checkSpoolTimes(DocumentContext stats, int spoolStageId, int 
parent, int times) {
+    List<Map<String, Object>> descendants = findDescendantById(stats, parent, 
spoolStageId);
+    Assert.assertEquals(descendants.size(), times, "Stage " + spoolStageId + " 
should be descended from stage "
+        + parent + " exactly " + times + " times");
+    Map<String, Object> firstSpool = descendants.get(0);
+    for (int i = 1; i < descendants.size(); i++) {
+      Assert.assertEquals(descendants.get(i), firstSpool, "Stage " + 
spoolStageId + " should be the same in "
+          + "all " + times + " descendants");
+    }
+  }
 
-    Assert.assertEquals(stage4On3, stage4On7, "Stage 4 should be the same in 
both stage 3 and stage 7");
+  private void checkSpoolSame(DocumentContext stats, int spoolStageId, int... 
parents) {
+    List<Pair<Integer, List<Map<String, Object>>>> spools = 
Arrays.stream(parents)
+        .mapToObj(parent -> Pair.of(parent, findDescendantById(stats, parent, 
spoolStageId)))
+        .collect(Collectors.toList());
+    Pair<Integer, List<Map<String, Object>>> notEmpty = spools.stream()
+        .filter(s -> !s.getValue().isEmpty())
+        .findFirst()
+        .orElse(null);
+    if (notEmpty == null) {
+      Assert.fail("None of the parent nodes " + Arrays.toString(parents) + " 
have a descendant with id "
+          + spoolStageId);
+    }
+    List<Pair<Integer, List<Map<String, Object>>>> allNotEqual = 
spools.stream()
+        .filter(s -> !s.getValue().get(0).equals(notEmpty.getValue().get(0)))
+        .collect(Collectors.toList());
+    if (!allNotEqual.isEmpty()) {
+      Assert.fail("The descendant with id " + spoolStageId + " is not the same 
in all parent nodes "
+          + spools);
+    }
   }
 
   @AfterClass
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index 30550da508..472527ca02 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -247,7 +247,7 @@ public class QueryEnvironment {
               onServerExplainer, explainPlanVerbose, 
RelBuilder.create(_config));
 
           RelNode explainedNode = 
MultiStageExplainAskingServersUtils.modifyRel(relRoot.rel,
-              dispatchableSubPlan.getQueryStageList(), nodeTracker, 
serversExplainer);
+              dispatchableSubPlan.getQueryStages(), nodeTracker, 
serversExplainer);
 
           String explainStr = PlannerUtils.explainPlan(explainedNode, format, 
level);
 
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
index ed64d23765..4152b230f8 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.query.planner;
 
+import com.google.common.base.Preconditions;
 import java.util.List;
 import org.apache.pinot.query.planner.plannode.PlanNode;
 
@@ -37,9 +38,16 @@ public class PlanFragment {
   public PlanFragment(int fragmentId, PlanNode fragmentRoot, 
List<PlanFragment> children) {
     _fragmentId = fragmentId;
     _fragmentRoot = fragmentRoot;
+    Preconditions.checkArgument(fragmentRoot.getStageId() == fragmentId,
+        "Fragment root stageId: %s does not match fragmentId: %s", 
fragmentRoot.getStageId(), fragmentId);
     _children = children;
   }
 
+  /**
+   * Returns the fragment id
+   *
+   * <p>Fragment id is the stage id of the fragment root.
+   */
   public int getFragmentId() {
     return _fragmentId;
   }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PhysicalExplainPlanVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PhysicalExplainPlanVisitor.java
index b91783a186..615ee3ba05 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PhysicalExplainPlanVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PhysicalExplainPlanVisitor.java
@@ -67,16 +67,16 @@ public class PhysicalExplainPlanVisitor implements 
PlanNodeVisitor<StringBuilder
    * @return a String representation of the query plan tree
    */
   public static String explain(DispatchableSubPlan dispatchableSubPlan) {
-    if (dispatchableSubPlan.getQueryStageList().isEmpty()) {
+    if (dispatchableSubPlan.getQueryStageMap().isEmpty()) {
       return "EMPTY";
     }
 
     // the root of a query plan always only has a single node
     QueryServerInstance rootServer =
-        
dispatchableSubPlan.getQueryStageList().get(0).getServerInstanceToWorkerIdMap()
+        
dispatchableSubPlan.getQueryStageMap().get(0).getServerInstanceToWorkerIdMap()
             .keySet().iterator().next();
     return explainFrom(dispatchableSubPlan,
-        
dispatchableSubPlan.getQueryStageList().get(0).getPlanFragment().getFragmentRoot(),
 rootServer);
+        
dispatchableSubPlan.getQueryStageMap().get(0).getPlanFragment().getFragmentRoot(),
 rootServer);
   }
 
   /**
@@ -173,7 +173,7 @@ public class PhysicalExplainPlanVisitor implements 
PlanNodeVisitor<StringBuilder
 
     MailboxSendNode sender = node.getSender();
     int senderStageId = node.getSenderStageId();
-    DispatchablePlanFragment dispatchablePlanFragment = 
_dispatchableSubPlan.getQueryStageList().get(senderStageId);
+    DispatchablePlanFragment dispatchablePlanFragment = 
_dispatchableSubPlan.getQueryStageMap().get(senderStageId);
 
     Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
         dispatchablePlanFragment.getServerInstanceToWorkerIdMap();
@@ -219,7 +219,7 @@ public class PhysicalExplainPlanVisitor implements 
PlanNodeVisitor<StringBuilder
     // This iterator is guaranteed to be sorted by stageId
     for (Integer receiverStageId : node.getReceiverStageIds()) {
       List<MailboxInfo> receiverMailboxInfos =
-          
_dispatchableSubPlan.getQueryStageList().get(node.getStageId()).getWorkerMetadataList().get(context._workerId)
+          
_dispatchableSubPlan.getQueryStageMap().get(node.getStageId()).getWorkerMetadataList().get(context._workerId)
               .getMailboxInfosMap().get(receiverStageId).getMailboxInfos();
       // Sort to ensure print order
       Stream<String> stageDescriptions = receiverMailboxInfos.stream()
@@ -248,7 +248,7 @@ public class PhysicalExplainPlanVisitor implements 
PlanNodeVisitor<StringBuilder
   public StringBuilder visitTableScan(TableScanNode node, Context context) {
     return appendInfo(node, context)
         .append(' ')
-        .append(_dispatchableSubPlan.getQueryStageList()
+        .append(_dispatchableSubPlan.getQueryStageMap()
             .get(node.getStageId())
             .getWorkerIdToSegmentsMap()
             .get(context._host))
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
index fbd77c4852..5aaf277858 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
@@ -19,11 +19,14 @@
 package org.apache.pinot.query.planner.physical;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Queue;
 import java.util.Set;
 import org.apache.calcite.runtime.PairList;
 import org.apache.pinot.query.context.PlannerContext;
@@ -33,9 +36,12 @@ import org.apache.pinot.query.routing.MailboxInfos;
 import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.routing.WorkerManager;
 import org.apache.pinot.query.routing.WorkerMetadata;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 public class DispatchablePlanContext {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(DispatchablePlanContext.class);
   private final WorkerManager _workerManager;
 
   private final long _requestId;
@@ -86,10 +92,8 @@ public class DispatchablePlanContext {
     return _dispatchablePlanStageRootMap;
   }
 
-  public List<DispatchablePlanFragment> 
constructDispatchablePlanFragmentList(PlanFragment subPlanRoot) {
-    DispatchablePlanFragment[] dispatchablePlanFragmentArray =
-        new DispatchablePlanFragment[_dispatchablePlanStageRootMap.size()];
-    createDispatchablePlanFragmentList(dispatchablePlanFragmentArray, 
subPlanRoot);
+  public Map<Integer, DispatchablePlanFragment> 
constructDispatchablePlanFragmentMap(PlanFragment subPlanRoot) {
+    Map<Integer, DispatchablePlanFragment> dispatchablePlanFragmentMap = 
createDispatchablePlanFragmentMap(subPlanRoot);
     for (Map.Entry<Integer, DispatchablePlanMetadata> planMetadataEntry : 
_dispatchablePlanMetadataMap.entrySet()) {
       int stageId = planMetadataEntry.getKey();
       DispatchablePlanMetadata dispatchablePlanMetadata = 
planMetadataEntry.getValue();
@@ -115,7 +119,7 @@ public class DispatchablePlanContext {
       }
 
       // set the stageMetadata
-      DispatchablePlanFragment dispatchablePlanFragment = 
dispatchablePlanFragmentArray[stageId];
+      DispatchablePlanFragment dispatchablePlanFragment = 
dispatchablePlanFragmentMap.get(stageId);
       
dispatchablePlanFragment.setWorkerMetadataList(Arrays.asList(workerMetadataArray));
       if (workerIdToSegmentsMap != null) {
         
dispatchablePlanFragment.setWorkerIdToSegmentsMap(workerIdToSegmentsMap);
@@ -130,14 +134,26 @@ public class DispatchablePlanContext {
         
dispatchablePlanFragment.setTimeBoundaryInfo(dispatchablePlanMetadata.getTimeBoundaryInfo());
       }
     }
-    return Arrays.asList(dispatchablePlanFragmentArray);
+    return dispatchablePlanFragmentMap;
   }
 
-  private void createDispatchablePlanFragmentList(DispatchablePlanFragment[] 
dispatchablePlanFragmentArray,
-      PlanFragment planFragmentRoot) {
-    dispatchablePlanFragmentArray[planFragmentRoot.getFragmentId()] = new 
DispatchablePlanFragment(planFragmentRoot);
-    for (PlanFragment childPlanFragment : planFragmentRoot.getChildren()) {
-      createDispatchablePlanFragmentList(dispatchablePlanFragmentArray, 
childPlanFragment);
+  private Map<Integer, DispatchablePlanFragment> 
createDispatchablePlanFragmentMap(PlanFragment planFragmentRoot) {
+    HashMap<Integer, DispatchablePlanFragment> result =
+        Maps.newHashMapWithExpectedSize(_dispatchablePlanMetadataMap.size());
+    Queue<PlanFragment> pendingPlanFragmentIds = new ArrayDeque<>();
+    pendingPlanFragmentIds.add(planFragmentRoot);
+    while (!pendingPlanFragmentIds.isEmpty()) {
+      PlanFragment planFragment = pendingPlanFragmentIds.poll();
+      int planFragmentId = planFragment.getFragmentId();
+
+      if (result.containsKey(planFragmentId)) { // this can happen if some 
stage is spooled.
+        LOGGER.debug("Skipping already visited stage {}", planFragmentId);
+        continue;
+      }
+      result.put(planFragmentId, new DispatchablePlanFragment(planFragment));
+
+      pendingPlanFragmentIds.addAll(planFragment.getChildren());
     }
+    return result;
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
index 5299b08ce7..83458f6cc7 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
@@ -18,9 +18,12 @@
  */
 package org.apache.pinot.query.planner.physical;
 
+import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.SortedSet;
+import java.util.TreeSet;
 import org.apache.calcite.runtime.PairList;
 import org.apache.pinot.core.util.QueryMultiThreadingUtils;
 
@@ -40,24 +43,59 @@ import org.apache.pinot.core.util.QueryMultiThreadingUtils;
  */
 public class DispatchableSubPlan {
   private final PairList<Integer, String> _queryResultFields;
-  private final List<DispatchablePlanFragment> _queryStageList;
+
+  /**
+   * Map from stage id to stage plan.
+   */
+  private final Map<Integer, DispatchablePlanFragment> _queryStageMap;
   private final Set<String> _tableNames;
   private final Map<String, Set<String>> _tableToUnavailableSegmentsMap;
 
-  public DispatchableSubPlan(PairList<Integer, String> fields, 
List<DispatchablePlanFragment> queryStageList,
+  public DispatchableSubPlan(PairList<Integer, String> fields,
+      Map<Integer, DispatchablePlanFragment> queryStageMap,
       Set<String> tableNames, Map<String, Set<String>> 
tableToUnavailableSegmentsMap) {
     _queryResultFields = fields;
-    _queryStageList = queryStageList;
+    _queryStageMap = queryStageMap;
     _tableNames = tableNames;
     _tableToUnavailableSegmentsMap = tableToUnavailableSegmentsMap;
   }
 
   /**
-   * Get the list of stage plan root node.
+   * Get a map from stage id to stage plan.
    * @return stage plan map.
    */
-  public List<DispatchablePlanFragment> getQueryStageList() {
-    return _queryStageList;
+  public Map<Integer, DispatchablePlanFragment> getQueryStageMap() {
+    return _queryStageMap;
+  }
+
+  private static Comparator<DispatchablePlanFragment> byStageIdComparator() {
+    return Comparator.comparing(d -> d.getPlanFragment().getFragmentId());
+  }
+
+  /**
+   * Get the query stages.
+   *
+   * The returned set is sorted by stage id.
+   */
+  public SortedSet<DispatchablePlanFragment> getQueryStages() {
+    TreeSet<DispatchablePlanFragment> treeSet = new 
TreeSet<>(byStageIdComparator());
+    treeSet.addAll(_queryStageMap.values());
+    return treeSet;
+  }
+
+  /**
+   * Get the query stages without the root stage.
+   *
+   * The returned set is sorted by stage id.
+   */
+  public SortedSet<DispatchablePlanFragment> getQueryStagesWithoutRoot() {
+    SortedSet<DispatchablePlanFragment> result = getQueryStages();
+
+    DispatchablePlanFragment root = _queryStageMap.get(0);
+    if (root != null) {
+      result.remove(root);
+    }
+    return result;
   }
 
   /**
@@ -90,7 +128,7 @@ public class DispatchableSubPlan {
   public int getEstimatedNumQueryThreads() {
     int estimatedNumQueryThreads = 0;
     // Skip broker reduce root stage
-    for (DispatchablePlanFragment stage : _queryStageList.subList(1, 
_queryStageList.size())) {
+    for (DispatchablePlanFragment stage : getQueryStagesWithoutRoot()) {
       // Non-leaf stage
       if (stage.getWorkerIdToSegmentsMap().isEmpty()) {
         estimatedNumQueryThreads += stage.getWorkerMetadataList().size();
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
index 0828aa49ff..80ead2db2e 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
@@ -100,7 +100,7 @@ public class PinotDispatchPlanner {
   private static DispatchableSubPlan finalizeDispatchableSubPlan(PlanFragment 
subPlanRoot,
       DispatchablePlanContext dispatchablePlanContext) {
     return new DispatchableSubPlan(dispatchablePlanContext.getResultFields(),
-        
dispatchablePlanContext.constructDispatchablePlanFragmentList(subPlanRoot),
+        
dispatchablePlanContext.constructDispatchablePlanFragmentMap(subPlanRoot),
         dispatchablePlanContext.getTableNames(),
         
populateTableUnavailableSegments(dispatchablePlanContext.getDispatchablePlanMetadataMap()));
   }
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index e490fb6b12..0a9e02f40c 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -24,6 +24,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.stream.Collectors;
@@ -94,9 +95,9 @@ public class QueryCompilationTest extends 
QueryEnvironmentTestBase {
   }
 
   private static void assertGroupBySingletonAfterJoin(DispatchableSubPlan 
dispatchableSubPlan, boolean shouldRewrite) {
-    for (int stageId = 0; stageId < 
dispatchableSubPlan.getQueryStageList().size(); stageId++) {
+    for (int stageId = 0; stageId < 
dispatchableSubPlan.getQueryStageMap().size(); stageId++) {
       if (dispatchableSubPlan.getTableNames().size() == 0 && 
!PlannerUtils.isRootPlanFragment(stageId)) {
-        PlanNode node = 
dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment().getFragmentRoot();
+        PlanNode node = 
dispatchableSubPlan.getQueryStageMap().get(stageId).getPlanFragment().getFragmentRoot();
         while (node != null) {
           if (node instanceof JoinNode) {
             // JOIN is exchanged with hash distribution (data shuffle)
@@ -126,11 +127,11 @@ public class QueryCompilationTest extends 
QueryEnvironmentTestBase {
   public void testQueryAndAssertStageContentForJoin() {
     String query = "SELECT * FROM a JOIN b ON a.col1 = b.col2";
     DispatchableSubPlan dispatchableSubPlan = 
_queryEnvironment.planQuery(query);
-    List<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStageList();
+    Set<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStages();
     int numStages = stagePlans.size();
     assertEquals(numStages, 4);
-    for (int stageId = 0; stageId < numStages; stageId++) {
-      DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
+    for (DispatchablePlanFragment stagePlan : stagePlans) {
+      int stageId = stagePlan.getPlanFragment().getFragmentId();
       Map<QueryServerInstance, List<Integer>> serverToWorkerIdsMap = 
stagePlan.getServerInstanceToWorkerIdMap();
       int numServers = serverToWorkerIdsMap.size();
       String tableName = stagePlan.getTableName();
@@ -166,9 +167,9 @@ public class QueryCompilationTest extends 
QueryEnvironmentTestBase {
     String query = "SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON 
a.col1 = b.col2 "
         + "WHERE a.col3 >= 0 AND a.col2 IN ('b') AND b.col3 < 0";
     DispatchableSubPlan dispatchableSubPlan = 
_queryEnvironment.planQuery(query);
-    List<DispatchablePlanFragment> intermediateStages =
-        dispatchableSubPlan.getQueryStageList().stream().filter(q -> 
q.getTableName() == null)
-            .collect(Collectors.toList());
+    List<DispatchablePlanFragment> intermediateStages = 
dispatchableSubPlan.getQueryStageMap().values().stream()
+        .filter(q -> q.getTableName() == null)
+        .collect(Collectors.toList());
     // Assert that no project of filter node for any intermediate stage 
because all should've been pushed down.
     for (DispatchablePlanFragment dispatchablePlanFragment : 
intermediateStages) {
       PlanNode roots = 
dispatchablePlanFragment.getPlanFragment().getFragmentRoot();
@@ -180,25 +181,25 @@ public class QueryCompilationTest extends 
QueryEnvironmentTestBase {
   public void testQueryRoutingManagerCompilation() {
     String query = "SELECT * FROM d_OFFLINE";
     DispatchableSubPlan dispatchableSubPlan = 
_queryEnvironment.planQuery(query);
-    List<DispatchablePlanFragment> tableScanMetadataList =
-        dispatchableSubPlan.getQueryStageList().stream().filter(stageMetadata 
-> stageMetadata.getTableName() != null)
-            .collect(Collectors.toList());
+    List<DispatchablePlanFragment> tableScanMetadataList = 
dispatchableSubPlan.getQueryStageMap().values().stream()
+        .filter(stageMetadata -> stageMetadata.getTableName() != null)
+        .collect(Collectors.toList());
     assertEquals(tableScanMetadataList.size(), 1);
     
assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(),
 2);
 
     query = "SELECT * FROM d_REALTIME";
     dispatchableSubPlan = _queryEnvironment.planQuery(query);
-    tableScanMetadataList =
-        dispatchableSubPlan.getQueryStageList().stream().filter(stageMetadata 
-> stageMetadata.getTableName() != null)
-            .collect(Collectors.toList());
+    tableScanMetadataList = 
dispatchableSubPlan.getQueryStageMap().values().stream()
+        .filter(stageMetadata -> stageMetadata.getTableName() != null)
+        .collect(Collectors.toList());
     assertEquals(tableScanMetadataList.size(), 1);
     
assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(),
 1);
 
     query = "SELECT * FROM d";
     dispatchableSubPlan = _queryEnvironment.planQuery(query);
-    tableScanMetadataList =
-        dispatchableSubPlan.getQueryStageList().stream().filter(stageMetadata 
-> stageMetadata.getTableName() != null)
-            .collect(Collectors.toList());
+    tableScanMetadataList = 
dispatchableSubPlan.getQueryStageMap().values().stream()
+        .filter(stageMetadata -> stageMetadata.getTableName() != null)
+        .collect(Collectors.toList());
     assertEquals(tableScanMetadataList.size(), 1);
     
assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(),
 2);
   }
@@ -260,11 +261,11 @@ public class QueryCompilationTest extends 
QueryEnvironmentTestBase {
     String query =
         "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ 
col1, COUNT(*) FROM b GROUP BY col1";
     DispatchableSubPlan dispatchableSubPlan = 
_queryEnvironment.planQuery(query);
-    List<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStageList();
+    Set<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStages();
     int numStages = stagePlans.size();
     assertEquals(numStages, 2);
-    for (int stageId = 0; stageId < numStages; stageId++) {
-      DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
+    for (DispatchablePlanFragment stagePlan : stagePlans) {
+      int stageId = stagePlan.getPlanFragment().getFragmentId();
       Map<QueryServerInstance, List<Integer>> serverToWorkerIdsMap = 
stagePlan.getServerInstanceToWorkerIdMap();
       int numServers = serverToWorkerIdsMap.size();
       String tableName = stagePlan.getTableName();
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/PlanNodeSerDeTest.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/PlanNodeSerDeTest.java
index 1f1825fb0a..8fc02e7812 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/PlanNodeSerDeTest.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/PlanNodeSerDeTest.java
@@ -32,7 +32,7 @@ public class PlanNodeSerDeTest extends 
QueryEnvironmentTestBase {
   @Test(dataProvider = "testQueryDataProvider")
   public void testQueryStagePlanSerDe(String query) {
     DispatchableSubPlan dispatchableSubPlan = 
_queryEnvironment.planQuery(query);
-    for (DispatchablePlanFragment dispatchablePlanFragment : 
dispatchableSubPlan.getQueryStageList()) {
+    for (DispatchablePlanFragment dispatchablePlanFragment : 
dispatchableSubPlan.getQueryStages()) {
       PlanNode stagePlan = 
dispatchablePlanFragment.getPlanFragment().getFragmentRoot();
       PlanNode deserializedStagePlan = 
PlanNodeDeserializer.process(PlanNodeSerializer.process(stagePlan));
       assertEquals(stagePlan, deserializedStagePlan);
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
index cd48518e2f..6bad667550 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
@@ -19,8 +19,9 @@
 package org.apache.pinot.query.runtime;
 
 import com.fasterxml.jackson.databind.node.ObjectNode;
-import java.util.ArrayList;
+import com.google.common.collect.Maps;
 import java.util.List;
+import java.util.Map;
 import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
 import org.apache.pinot.query.planner.plannode.PlanNode;
 import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
@@ -28,16 +29,16 @@ import org.apache.pinot.spi.utils.JsonUtils;
 
 
 public class MultiStageStatsTreeBuilder {
-  private final List<PlanNode> _planNodes;
+  private final Map<Integer, PlanNode> _planNodes;
   private final List<? extends MultiStageQueryStats.StageStats> _queryStats;
-  private final List<DispatchablePlanFragment> _planFragments;
+  private final Map<Integer, DispatchablePlanFragment> _planFragments;
 
-  public MultiStageStatsTreeBuilder(List<DispatchablePlanFragment> 
planFragments,
+  public MultiStageStatsTreeBuilder(Map<Integer, DispatchablePlanFragment> 
planFragments,
       List<? extends MultiStageQueryStats.StageStats> queryStats) {
     _planFragments = planFragments;
-    _planNodes = new ArrayList<>(planFragments.size());
-    for (DispatchablePlanFragment stagePlan : planFragments) {
-      _planNodes.add(stagePlan.getPlanFragment().getFragmentRoot());
+    _planNodes = Maps.newHashMapWithExpectedSize(planFragments.size());
+    for (Map.Entry<Integer, DispatchablePlanFragment> entry : 
planFragments.entrySet()) {
+      _planNodes.put(entry.getKey(), 
entry.getValue().getPlanFragment().getFragmentRoot());
     }
     _queryStats = queryStats;
   }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 38d617a255..bd596b116d 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -20,6 +20,7 @@ package org.apache.pinot.query.service.dispatch;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.InvalidProtocolBufferException;
 import io.grpc.ConnectivityState;
@@ -44,6 +45,7 @@ import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import javax.annotation.Nullable;
 import org.apache.calcite.runtime.PairList;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.config.TlsConfig;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.exception.QueryException;
@@ -158,7 +160,7 @@ public class QueryDispatcher {
     long requestId = context.getRequestId();
     List<PlanNode> planNodes = new ArrayList<>();
 
-    List<DispatchablePlanFragment> plans = Collections.singletonList(fragment);
+    Set<DispatchablePlanFragment> plans = Collections.singleton(fragment);
     Set<QueryServerInstance> servers = new HashSet<>();
     try {
       SendRequest<List<Worker.ExplainResponse>> requestSender = 
DispatchClient::explain;
@@ -199,8 +201,7 @@ public class QueryDispatcher {
       Map<String, String> queryOptions)
       throws Exception {
     SendRequest<Worker.QueryResponse> requestSender = DispatchClient::submit;
-    List<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStageList();
-    List<DispatchablePlanFragment> plansWithoutRoot = stagePlans.subList(1, 
stagePlans.size());
+    Set<DispatchablePlanFragment> plansWithoutRoot = 
dispatchableSubPlan.getQueryStagesWithoutRoot();
     execute(requestId, plansWithoutRoot, timeoutMs, queryOptions, 
requestSender, serversOut,
         (response, serverInstance) -> {
       if 
(response.containsMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR))
 {
@@ -244,7 +245,7 @@ public class QueryDispatcher {
     return _serversByQuery != null;
   }
 
-  private <E> void execute(long requestId, List<DispatchablePlanFragment> 
stagePlans,
+  private <E> void execute(long requestId, Set<DispatchablePlanFragment> 
stagePlans,
       long timeoutMs, Map<String, String> queryOptions,
       SendRequest<E> sendRequest, Set<QueryServerInstance> serverInstancesOut,
       BiConsumer<E, QueryServerInstance> resultConsumer)
@@ -252,7 +253,8 @@ public class QueryDispatcher {
 
     Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
 
-    List<StageInfo> stageInfos = serializePlanFragments(stagePlans, 
serverInstancesOut, deadline);
+    Map<DispatchablePlanFragment, StageInfo> stageInfos =
+        serializePlanFragments(stagePlans, serverInstancesOut, deadline);
 
     if (serverInstancesOut.isEmpty()) {
       throw new RuntimeException("No server instances to dispatch query to");
@@ -272,8 +274,7 @@ public class QueryDispatcher {
               serverInstance);
         }
       };
-      Worker.QueryRequest requestBuilder =
-          createRequest(serverInstance, stagePlans, stageInfos, 
protoRequestMetadata);
+      Worker.QueryRequest requestBuilder = createRequest(serverInstance, 
stageInfos, protoRequestMetadata);
       DispatchClient dispatchClient = 
getOrCreateDispatchClient(serverInstance);
 
       try {
@@ -333,11 +334,12 @@ public class QueryDispatcher {
   }
 
   private static Worker.QueryRequest createRequest(QueryServerInstance 
serverInstance,
-      List<DispatchablePlanFragment> stagePlans, List<StageInfo> stageInfos, 
ByteString protoRequestMetadata) {
+      Map<DispatchablePlanFragment, StageInfo> stageInfos, ByteString 
protoRequestMetadata) {
     Worker.QueryRequest.Builder requestBuilder = 
Worker.QueryRequest.newBuilder();
     
requestBuilder.setVersion(CommonConstants.MultiStageQueryRunner.PlanVersions.V1);
-    for (int i = 0; i < stagePlans.size(); i++) {
-      DispatchablePlanFragment stagePlan = stagePlans.get(i);
+
+    for (Map.Entry<DispatchablePlanFragment, StageInfo> entry : 
stageInfos.entrySet()) {
+      DispatchablePlanFragment stagePlan = entry.getKey();
       List<Integer> workerIds = 
stagePlan.getServerInstanceToWorkerIdMap().get(serverInstance);
       if (workerIds != null) { // otherwise this server doesn't need to 
execute this stage
         List<WorkerMetadata> stageWorkerMetadataList = 
stagePlan.getWorkerMetadataList();
@@ -347,21 +349,18 @@ public class QueryDispatcher {
         }
         List<Worker.WorkerMetadata> protoWorkerMetadataList =
             QueryPlanSerDeUtils.toProtoWorkerMetadataList(workerMetadataList);
-        StageInfo stageInfo = stageInfos.get(i);
+        StageInfo stageInfo = entry.getValue();
 
-        //@formatter:off
         Worker.StagePlan requestStagePlan = Worker.StagePlan.newBuilder()
             .setRootNode(stageInfo._rootNode)
             .setStageMetadata(
                 Worker.StageMetadata.newBuilder()
-                    // this is a leak from submitAndReduce (id may be 
different in explain), but it's fine for now
-                    .setStageId(i + 1)
+                    .setStageId(stagePlan.getPlanFragment().getFragmentId())
                     .addAllWorkerMetadata(protoWorkerMetadataList)
                     .setCustomProperty(stageInfo._customProperty)
                     .build()
             )
             .build();
-        //@formatter:on
         requestBuilder.addStagePlan(requestStagePlan);
       }
     }
@@ -379,18 +378,22 @@ public class QueryDispatcher {
     return requestMetadata;
   }
 
-  private List<StageInfo> 
serializePlanFragments(List<DispatchablePlanFragment> stagePlans,
+  private Map<DispatchablePlanFragment, StageInfo> serializePlanFragments(
+      Set<DispatchablePlanFragment> stagePlans,
       Set<QueryServerInstance> serverInstances, Deadline deadline)
       throws InterruptedException, ExecutionException, TimeoutException {
-    List<CompletableFuture<StageInfo>> stageInfoFutures = new 
ArrayList<>(stagePlans.size());
+    List<CompletableFuture<Pair<DispatchablePlanFragment, StageInfo>>> 
stageInfoFutures =
+        new ArrayList<>(stagePlans.size());
     for (DispatchablePlanFragment stagePlan : stagePlans) {
       
serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
-      stageInfoFutures.add(CompletableFuture.supplyAsync(() -> 
serializePlanFragment(stagePlan), _executorService));
+      stageInfoFutures.add(
+          CompletableFuture.supplyAsync(() -> Pair.of(stagePlan, 
serializePlanFragment(stagePlan)), _executorService));
     }
-    List<StageInfo> stageInfos = new ArrayList<>(stagePlans.size());
+    Map<DispatchablePlanFragment, StageInfo> stageInfos = 
Maps.newHashMapWithExpectedSize(stagePlans.size());
     try {
-      for (CompletableFuture<StageInfo> future : stageInfoFutures) {
-        
stageInfos.add(future.get(deadline.timeRemaining(TimeUnit.MILLISECONDS), 
TimeUnit.MILLISECONDS));
+      for (CompletableFuture<Pair<DispatchablePlanFragment, StageInfo>> future 
: stageInfoFutures) {
+        Pair<DispatchablePlanFragment, StageInfo> pair = future.get();
+        stageInfos.put(pair.getKey(), pair.getValue());
       }
     } finally {
       for (CompletableFuture<?> future : stageInfoFutures) {
@@ -469,7 +472,7 @@ public class QueryDispatcher {
     long startTimeMs = System.currentTimeMillis();
     long deadlineMs = startTimeMs + timeoutMs;
     // NOTE: Reduce stage is always stage 0
-    DispatchablePlanFragment stagePlan = subPlan.getQueryStageList().get(0);
+    DispatchablePlanFragment stagePlan = subPlan.getQueryStageMap().get(0);
     PlanFragment planFragment = stagePlan.getPlanFragment();
     PlanNode rootNode = planFragment.getFragmentRoot();
 
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTestBase.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTestBase.java
index a00b6bbcda..b40192e4e7 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTestBase.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTestBase.java
@@ -38,6 +38,7 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
@@ -131,13 +132,12 @@ public abstract class QueryRunnerTestBase extends 
QueryTestSet {
     }
 
     // Submission Stub logic are mimic {@link QueryServer}
-    List<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStageList();
+    Set<DispatchablePlanFragment> stagePlans = 
dispatchableSubPlan.getQueryStagesWithoutRoot();
     List<CompletableFuture<?>> submissionStubs = new ArrayList<>();
-    for (int stageId = 0; stageId < stagePlans.size(); stageId++) {
-      if (stageId != 0) {
-        
submissionStubs.addAll(processDistributedStagePlans(dispatchableSubPlan, 
requestId, stageId,
-            requestMetadataMap));
-      }
+    for (DispatchablePlanFragment stagePlan : stagePlans) {
+      int stageId = stagePlan.getPlanFragment().getFragmentId();
+      submissionStubs.addAll(processDistributedStagePlans(dispatchableSubPlan, 
requestId, stageId,
+          requestMetadataMap));
     }
     try {
       CompletableFuture.allOf(submissionStubs.toArray(new 
CompletableFuture[0])).get(timeoutMs, TimeUnit.MILLISECONDS);
@@ -159,7 +159,7 @@ public abstract class QueryRunnerTestBase extends 
QueryTestSet {
 
   protected List<CompletableFuture<?>> 
processDistributedStagePlans(DispatchableSubPlan dispatchableSubPlan,
       long requestId, int stageId, Map<String, String> requestMetadataMap) {
-    DispatchablePlanFragment dispatchableStagePlan = 
dispatchableSubPlan.getQueryStageList().get(stageId);
+    DispatchablePlanFragment dispatchableStagePlan = 
dispatchableSubPlan.getQueryStageMap().get(stageId);
     List<WorkerMetadata> stageWorkerMetadataList = 
dispatchableStagePlan.getWorkerMetadataList();
     List<CompletableFuture<?>> submissionStubs = new ArrayList<>();
     for (Map.Entry<QueryServerInstance, List<Integer>> entry : 
dispatchableStagePlan.getServerInstanceToWorkerIdMap()
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
index f94d85c92b..7d78bd6f94 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
@@ -287,7 +287,7 @@ public class ResourceBasedQueriesTest extends 
QueryRunnerTestBase {
 
   private Map<String, JsonNode> tableToStats(String sql, 
QueryDispatcher.QueryResult queryResult) {
 
-    List<DispatchablePlanFragment> planNodes = 
planQuery(sql).getQueryPlan().getQueryStageList();
+    Map<Integer, DispatchablePlanFragment> planNodes = 
planQuery(sql).getQueryPlan().getQueryStageMap();
 
     MultiStageStatsTreeBuilder multiStageStatsTreeBuilder =
         new MultiStageStatsTreeBuilder(planNodes, queryResult.getQueryStats());
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
index 7a14a2a4c6..fed4fbd002 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
@@ -28,6 +28,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
 import org.apache.pinot.common.proto.Plan;
@@ -119,10 +120,10 @@ public class QueryServerTest extends QueryTestSet {
   public void testWorkerAcceptsWorkerRequestCorrect(String sql)
       throws Exception {
     DispatchableSubPlan queryPlan = _queryEnvironment.planQuery(sql);
-    List<DispatchablePlanFragment> stagePlans = queryPlan.getQueryStageList();
-    int numStages = stagePlans.size();
+    Set<DispatchablePlanFragment> stagePlans = 
queryPlan.getQueryStagesWithoutRoot();
     // Ignore reduce stage (stage 0)
-    for (int stageId = 1; stageId < numStages; stageId++) {
+    for (DispatchablePlanFragment stagePlan : stagePlans) {
+      int stageId = stagePlan.getPlanFragment().getFragmentId();
       // only get one worker request out.
       Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, stageId);
       Map<String, String> requestMetadata = 
QueryPlanSerDeUtils.fromProtoProperties(queryRequest.getMetadata());
@@ -131,10 +132,8 @@ public class QueryServerTest extends QueryTestSet {
       Worker.QueryResponse resp = submitRequest(queryRequest, requestMetadata);
       
assertTrue(resp.getMetadataMap().containsKey(CommonConstants.Query.Response.ServerResponseStatus.STATUS_OK));
 
-      DispatchablePlanFragment dispatchableStagePlan = stagePlans.get(stageId);
-      List<WorkerMetadata> workerMetadataList = 
dispatchableStagePlan.getWorkerMetadataList();
-      StageMetadata stageMetadata =
-          new StageMetadata(stageId, workerMetadataList, 
dispatchableStagePlan.getCustomProperties());
+      List<WorkerMetadata> workerMetadataList = 
stagePlan.getWorkerMetadataList();
+      StageMetadata stageMetadata = new StageMetadata(stageId, 
workerMetadataList, stagePlan.getCustomProperties());
 
       // ensure mock query runner received correctly deserialized payload.
       QueryRunner mockRunner = 
_queryRunnerMap.get(Integer.parseInt(requestMetadata.get(KEY_OF_SERVER_INSTANCE_PORT)));
@@ -143,10 +142,10 @@ public class QueryServerTest extends QueryTestSet {
       // since submitRequest is async, we need to wait for the mockRunner to 
receive the query payload.
       TestUtils.waitForCondition(aVoid -> {
         try {
-          verify(mockRunner, 
times(workerMetadataList.size())).processQuery(any(), argThat(stagePlan -> {
-            PlanNode planNode = 
dispatchableStagePlan.getPlanFragment().getFragmentRoot();
-            return planNode.equals(stagePlan.getRootNode()) && 
isStageMetadataEqual(stageMetadata,
-                stagePlan.getStageMetadata());
+          verify(mockRunner, 
times(workerMetadataList.size())).processQuery(any(), argThat(stagePlanArg -> {
+            PlanNode planNode = stagePlan.getPlanFragment().getFragmentRoot();
+            return planNode.equals(stagePlanArg.getRootNode()) && 
isStageMetadataEqual(stageMetadata,
+                stagePlanArg.getStageMetadata());
           }), argThat(requestMetadataMap -> requestId.equals(
               
requestMetadataMap.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID))),
 any());
           return true;
@@ -206,7 +205,7 @@ public class QueryServerTest extends QueryTestSet {
   }
 
   private Worker.QueryRequest getQueryRequest(DispatchableSubPlan queryPlan, 
int stageId) {
-    DispatchablePlanFragment stagePlan = 
queryPlan.getQueryStageList().get(stageId);
+    DispatchablePlanFragment stagePlan = 
queryPlan.getQueryStageMap().get(stageId);
     Plan.PlanNode rootNode = 
PlanNodeSerializer.process(stagePlan.getPlanFragment().getFragmentRoot());
     List<Worker.WorkerMetadata> workerMetadataList =
         
QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan.getWorkerMetadataList());


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to