This is an automated email from the ASF dual-hosted git repository.
gortiz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new fd3b079615 Refactor DispatchableSubPlan to store fragments by stage id
(#15135)
fd3b079615 is described below
commit fd3b0796156d51ce4239cca02fd925bf1728c885
Author: Gonzalo Ortiz Jaureguizar <[email protected]>
AuthorDate: Fri Feb 28 15:06:54 2025 +0100
Refactor DispatchableSubPlan to store fragments by stage id (#15135)
Previously, they were stored on a list whose id was the stage id. This made
sense when stage ids formed a dense set, but spools broke that assumption.
---
.../MultiStageBrokerRequestHandler.java | 6 +-
.../integration/tests/SpoolIntegrationTest.java | 161 ++++++++++++++++++++-
.../org/apache/pinot/query/QueryEnvironment.java | 2 +-
.../apache/pinot/query/planner/PlanFragment.java | 8 +
.../explain/PhysicalExplainPlanVisitor.java | 12 +-
.../planner/physical/DispatchablePlanContext.java | 38 +++--
.../planner/physical/DispatchableSubPlan.java | 52 ++++++-
.../planner/physical/PinotDispatchPlanner.java | 2 +-
.../apache/pinot/query/QueryCompilationTest.java | 41 +++---
.../query/planner/serde/PlanNodeSerDeTest.java | 2 +-
.../query/runtime/MultiStageStatsTreeBuilder.java | 15 +-
.../query/service/dispatch/QueryDispatcher.java | 47 +++---
.../query/runtime/queries/QueryRunnerTestBase.java | 14 +-
.../runtime/queries/ResourceBasedQueriesTest.java | 2 +-
.../query/service/server/QueryServerTest.java | 23 ++-
15 files changed, 321 insertions(+), 104 deletions(-)
diff --git
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index 771e978341..80dba42f8e 100644
---
a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++
b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -263,7 +263,7 @@ public class MultiStageBrokerRequestHandler extends
BaseBrokerRequestHandler {
DispatchableSubPlan dispatchableSubPlan = queryPlanResult.getQueryPlan();
Set<QueryServerInstance> servers = new HashSet<>();
- for (DispatchablePlanFragment planFragment:
dispatchableSubPlan.getQueryStageList()) {
+ for (DispatchablePlanFragment planFragment:
dispatchableSubPlan.getQueryStageMap().values()) {
servers.addAll(planFragment.getServerInstances());
}
@@ -443,9 +443,9 @@ public class MultiStageBrokerRequestHandler extends
BaseBrokerRequestHandler {
private void fillOldBrokerResponseStats(BrokerResponseNativeV2
brokerResponse,
List<MultiStageQueryStats.StageStats.Closed> queryStats,
DispatchableSubPlan dispatchableSubPlan) {
try {
- List<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStageList();
+ Map<Integer, DispatchablePlanFragment> queryStageMap =
dispatchableSubPlan.getQueryStageMap();
- MultiStageStatsTreeBuilder treeBuilder = new
MultiStageStatsTreeBuilder(stagePlans, queryStats);
+ MultiStageStatsTreeBuilder treeBuilder = new
MultiStageStatsTreeBuilder(queryStageMap, queryStats);
brokerResponse.setStageStats(treeBuilder.jsonStatsByStage(0));
for (MultiStageQueryStats.StageStats.Closed stageStats : queryStats) {
if (stageStats != null) { // for example pipeline breaker may not have
stats
diff --git
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java
index 91f86f4f65..1c37808401 100644
---
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java
+++
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java
@@ -22,13 +22,17 @@ import com.fasterxml.jackson.databind.JsonNode;
import com.jayway.jsonpath.DocumentContext;
import com.jayway.jsonpath.JsonPath;
import java.io.File;
+import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.apache.pinot.spi.utils.CommonConstants;
import org.apache.pinot.util.TestUtils;
+import org.intellij.lang.annotations.Language;
import org.testcontainers.shaded.org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
@@ -117,13 +121,160 @@ public class SpoolIntegrationTest extends
BaseClusterIntegrationTest
JsonNode stats = jsonNode.get("stageStats");
assertNoError(jsonNode);
DocumentContext parsed = JsonPath.parse(stats.toString());
- List<Map<String, Object>> stage4On3 = parsed.read("$..[?(@.stage ==
3)]..[?(@.stage == 4)]");
- Assert.assertEquals(stage4On3.size(), 1, "Stage 4 should be descended from
stage 3 exactly once");
- List<Map<String, Object>> stage4On7 = parsed.read("$..[?(@.stage ==
7)]..[?(@.stage == 4)]");
- Assert.assertEquals(stage4On3.size(), 1, "Stage 4 should be descended from
stage 7 exactly once");
+ checkSpoolTimes(parsed, 4, 3, 1);
+ checkSpoolTimes(parsed, 4, 7, 1);
+ checkSpoolSame(parsed, 4, 3, 7);
+ }
+
+ /**
+ * Test a complex query with nested spools.
+ *
+ * Don't try to understand the query, just check that the spools are correct.
+ * This query is an actual simplification of a query used in production.
+ * It was the way we detected problems fixed in <a
href="https://github.com/apache/pinot/pull/15135">#15135</a>.
+ */
+ @Test
+ public void testNestedSpools()
+ throws Exception {
+ JsonNode jsonNode = postQuery("SET useSpools = true;\n"
+ + "\n"
+ + "WITH\n"
+ + " q1 AS (\n"
+ + " SELECT ArrTimeBlk as userUUID,\n"
+ + " Dest as deviceOS,\n"
+ + " SUM(ArrTime) AS totalTrips\n"
+ + " FROM mytable\n"
+ + " GROUP BY ArrTimeBlk, Dest\n"
+ + " ),\n"
+ + " q2 AS (\n"
+ + " SELECT userUUID,\n"
+ + " deviceOS,\n"
+ + " SUM(totalTrips) AS totalTrips,\n"
+ + " COUNT(DISTINCT userUUID) AS reach\n"
+ + " FROM q1\n"
+ + " GROUP BY userUUID,\n"
+ + " deviceOS\n"
+ + " ),\n"
+ + " q3 AS (\n"
+ + " SELECT userUUID,\n"
+ + " (totalTrips / reach) AS frequency\n"
+ + " FROM q2\n"
+ + " ),\n"
+ + " q4 AS (\n"
+ + " SELECT rd.userUUID,\n"
+ + " rd.deviceOS,\n"
+ + " rd.totalTrips as totalTrips,\n"
+ + " rd.reach AS reach\n"
+ + " FROM q2 rd\n"
+ + " ),\n"
+ + " q5 AS (\n"
+ + " SELECT userUUID,\n"
+ + " SUM(totalTrips) AS totalTrips\n"
+ + " FROM q4\n"
+ + " GROUP BY userUUID\n"
+ + " ),\n"
+ + " q6 AS (\n"
+ + " SELECT s.userUUID,\n"
+ + " s.totalTrips,\n"
+ + " (s.totalTrips / o.frequency) AS reach,\n"
+ + " 'some fake device' AS deviceOS\n"
+ + " FROM q5 s\n"
+ + " JOIN q3 o ON s.userUUID = o.userUUID\n"
+ + " ),\n"
+ + " q7 AS (\n"
+ + " SELECT rd.userUUID,\n"
+ + " rd.totalTrips,\n"
+ + " rd.reach,\n"
+ + " rd.deviceOS\n"
+ + " FROM q4 rd\n"
+ + " UNION ALL\n"
+ + " SELECT f.userUUID,\n"
+ + " f.totalTrips,\n"
+ + " f.reach,\n"
+ + " f.deviceOS\n"
+ + " FROM q6 f\n"
+ + " ),\n"
+ + " q8 AS (\n"
+ + " SELECT sd.*\n"
+ + " FROM q7 sd\n"
+ + " JOIN (\n"
+ + " SELECT deviceOS,\n"
+ + " PERCENTILETDigest(totalTrips, 20) AS p20\n"
+ + " FROM q7\n"
+ + " GROUP BY deviceOS\n"
+ + " ) q ON sd.deviceOS = q.deviceOS\n"
+ + " )\n"
+ + "SELECT *\n"
+ + "FROM q8");
+ JsonNode stats = jsonNode.get("stageStats");
+ assertNoError(jsonNode);
+ DocumentContext parsed = JsonPath.parse(stats.toString());
+
+ /*
+ * Stages are like follows:
+ * 1 -> 2 (union) -> 3 (aggr) -> 4 (leaf, spooled)
+ * -> 5 (join) -> 6 (aggr, spooled) -> 7 (aggr,
spooled) -> 4 (leaf, spooled)
+ * -> 9 (aggr) -> 4 (leaf, spooled)
+ * -> 11 (union) -> 12 (aggr) -> 4 (leaf, spooled)
+ * -> 14 (join) -> 6 (aggr, spooled) -> 7 (aggr,
spooled) -> 4 (leaf, spooled)
+ * -> 18 (aggr) -> 4 (leaf, spooled)
+ */
+ checkSpoolTimes(parsed, 6, 5, 1);
+ checkSpoolTimes(parsed, 6, 14, 1);
+ checkSpoolSame(parsed, 6, 5, 14);
+
+ checkSpoolTimes(parsed, 7, 6, 2);
+
+ checkSpoolTimes(parsed, 4, 3, 1);
+ checkSpoolTimes(parsed, 4, 7, 2); // because there are 2 copies of 7 as
well
+ checkSpoolTimes(parsed, 4, 9, 1);
+ checkSpoolTimes(parsed, 4, 12, 1);
+ checkSpoolTimes(parsed, 4, 18, 1);
+ checkSpoolSame(parsed, 4, 3, 7, 9, 12, 18);
+ }
+
+ /**
+ * Returns the nodes that have the given descendant id and also have the
given parent id as one of their ancestors.
+ * @param parent the parent id
+ * @param descendant the descendant id
+ */
+ private List<Map<String, Object>> findDescendantById(DocumentContext stats,
int parent, int descendant) {
+ @Language("jsonpath")
+ String jsonPath = "$..[?(@.stage == " + parent + ")]..[?(@.stage == " +
descendant + ")]";
+ return stats.read(jsonPath);
+ }
+
+ private void checkSpoolTimes(DocumentContext stats, int spoolStageId, int
parent, int times) {
+ List<Map<String, Object>> descendants = findDescendantById(stats, parent,
spoolStageId);
+ Assert.assertEquals(descendants.size(), times, "Stage " + spoolStageId + "
should be descended from stage "
+ + parent + " exactly " + times + " times");
+ Map<String, Object> firstSpool = descendants.get(0);
+ for (int i = 1; i < descendants.size(); i++) {
+ Assert.assertEquals(descendants.get(i), firstSpool, "Stage " +
spoolStageId + " should be the same in "
+ + "all " + times + " descendants");
+ }
+ }
- Assert.assertEquals(stage4On3, stage4On7, "Stage 4 should be the same in
both stage 3 and stage 7");
+ private void checkSpoolSame(DocumentContext stats, int spoolStageId, int...
parents) {
+ List<Pair<Integer, List<Map<String, Object>>>> spools =
Arrays.stream(parents)
+ .mapToObj(parent -> Pair.of(parent, findDescendantById(stats, parent,
spoolStageId)))
+ .collect(Collectors.toList());
+ Pair<Integer, List<Map<String, Object>>> notEmpty = spools.stream()
+ .filter(s -> !s.getValue().isEmpty())
+ .findFirst()
+ .orElse(null);
+ if (notEmpty == null) {
+ Assert.fail("None of the parent nodes " + Arrays.toString(parents) + "
have a descendant with id "
+ + spoolStageId);
+ }
+ List<Pair<Integer, List<Map<String, Object>>>> allNotEqual =
spools.stream()
+ .filter(s -> !s.getValue().get(0).equals(notEmpty.getValue().get(0)))
+ .collect(Collectors.toList());
+ if (!allNotEqual.isEmpty()) {
+ Assert.fail("The descendant with id " + spoolStageId + " is not the same
in all parent nodes "
+ + spools);
+ }
}
@AfterClass
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index 30550da508..472527ca02 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -247,7 +247,7 @@ public class QueryEnvironment {
onServerExplainer, explainPlanVerbose,
RelBuilder.create(_config));
RelNode explainedNode =
MultiStageExplainAskingServersUtils.modifyRel(relRoot.rel,
- dispatchableSubPlan.getQueryStageList(), nodeTracker,
serversExplainer);
+ dispatchableSubPlan.getQueryStages(), nodeTracker,
serversExplainer);
String explainStr = PlannerUtils.explainPlan(explainedNode, format,
level);
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
index ed64d23765..4152b230f8 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.query.planner;
+import com.google.common.base.Preconditions;
import java.util.List;
import org.apache.pinot.query.planner.plannode.PlanNode;
@@ -37,9 +38,16 @@ public class PlanFragment {
public PlanFragment(int fragmentId, PlanNode fragmentRoot,
List<PlanFragment> children) {
_fragmentId = fragmentId;
_fragmentRoot = fragmentRoot;
+ Preconditions.checkArgument(fragmentRoot.getStageId() == fragmentId,
+ "Fragment root stageId: %s does not match fragmentId: %s",
fragmentRoot.getStageId(), fragmentId);
_children = children;
}
+ /**
+ * Returns the fragment id
+ *
+ * <p>Fragment id is the stage id of the fragment root.
+ */
public int getFragmentId() {
return _fragmentId;
}
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PhysicalExplainPlanVisitor.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PhysicalExplainPlanVisitor.java
index b91783a186..615ee3ba05 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PhysicalExplainPlanVisitor.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PhysicalExplainPlanVisitor.java
@@ -67,16 +67,16 @@ public class PhysicalExplainPlanVisitor implements
PlanNodeVisitor<StringBuilder
* @return a String representation of the query plan tree
*/
public static String explain(DispatchableSubPlan dispatchableSubPlan) {
- if (dispatchableSubPlan.getQueryStageList().isEmpty()) {
+ if (dispatchableSubPlan.getQueryStageMap().isEmpty()) {
return "EMPTY";
}
// the root of a query plan always only has a single node
QueryServerInstance rootServer =
-
dispatchableSubPlan.getQueryStageList().get(0).getServerInstanceToWorkerIdMap()
+
dispatchableSubPlan.getQueryStageMap().get(0).getServerInstanceToWorkerIdMap()
.keySet().iterator().next();
return explainFrom(dispatchableSubPlan,
-
dispatchableSubPlan.getQueryStageList().get(0).getPlanFragment().getFragmentRoot(),
rootServer);
+
dispatchableSubPlan.getQueryStageMap().get(0).getPlanFragment().getFragmentRoot(),
rootServer);
}
/**
@@ -173,7 +173,7 @@ public class PhysicalExplainPlanVisitor implements
PlanNodeVisitor<StringBuilder
MailboxSendNode sender = node.getSender();
int senderStageId = node.getSenderStageId();
- DispatchablePlanFragment dispatchablePlanFragment =
_dispatchableSubPlan.getQueryStageList().get(senderStageId);
+ DispatchablePlanFragment dispatchablePlanFragment =
_dispatchableSubPlan.getQueryStageMap().get(senderStageId);
Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
dispatchablePlanFragment.getServerInstanceToWorkerIdMap();
@@ -219,7 +219,7 @@ public class PhysicalExplainPlanVisitor implements
PlanNodeVisitor<StringBuilder
// This iterator is guaranteed to be sorted by stageId
for (Integer receiverStageId : node.getReceiverStageIds()) {
List<MailboxInfo> receiverMailboxInfos =
-
_dispatchableSubPlan.getQueryStageList().get(node.getStageId()).getWorkerMetadataList().get(context._workerId)
+
_dispatchableSubPlan.getQueryStageMap().get(node.getStageId()).getWorkerMetadataList().get(context._workerId)
.getMailboxInfosMap().get(receiverStageId).getMailboxInfos();
// Sort to ensure print order
Stream<String> stageDescriptions = receiverMailboxInfos.stream()
@@ -248,7 +248,7 @@ public class PhysicalExplainPlanVisitor implements
PlanNodeVisitor<StringBuilder
public StringBuilder visitTableScan(TableScanNode node, Context context) {
return appendInfo(node, context)
.append(' ')
- .append(_dispatchableSubPlan.getQueryStageList()
+ .append(_dispatchableSubPlan.getQueryStageMap()
.get(node.getStageId())
.getWorkerIdToSegmentsMap()
.get(context._host))
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
index fbd77c4852..5aaf277858 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
@@ -19,11 +19,14 @@
package org.apache.pinot.query.planner.physical;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Queue;
import java.util.Set;
import org.apache.calcite.runtime.PairList;
import org.apache.pinot.query.context.PlannerContext;
@@ -33,9 +36,12 @@ import org.apache.pinot.query.routing.MailboxInfos;
import org.apache.pinot.query.routing.QueryServerInstance;
import org.apache.pinot.query.routing.WorkerManager;
import org.apache.pinot.query.routing.WorkerMetadata;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public class DispatchablePlanContext {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(DispatchablePlanContext.class);
private final WorkerManager _workerManager;
private final long _requestId;
@@ -86,10 +92,8 @@ public class DispatchablePlanContext {
return _dispatchablePlanStageRootMap;
}
- public List<DispatchablePlanFragment>
constructDispatchablePlanFragmentList(PlanFragment subPlanRoot) {
- DispatchablePlanFragment[] dispatchablePlanFragmentArray =
- new DispatchablePlanFragment[_dispatchablePlanStageRootMap.size()];
- createDispatchablePlanFragmentList(dispatchablePlanFragmentArray,
subPlanRoot);
+ public Map<Integer, DispatchablePlanFragment>
constructDispatchablePlanFragmentMap(PlanFragment subPlanRoot) {
+ Map<Integer, DispatchablePlanFragment> dispatchablePlanFragmentMap =
createDispatchablePlanFragmentMap(subPlanRoot);
for (Map.Entry<Integer, DispatchablePlanMetadata> planMetadataEntry :
_dispatchablePlanMetadataMap.entrySet()) {
int stageId = planMetadataEntry.getKey();
DispatchablePlanMetadata dispatchablePlanMetadata =
planMetadataEntry.getValue();
@@ -115,7 +119,7 @@ public class DispatchablePlanContext {
}
// set the stageMetadata
- DispatchablePlanFragment dispatchablePlanFragment =
dispatchablePlanFragmentArray[stageId];
+ DispatchablePlanFragment dispatchablePlanFragment =
dispatchablePlanFragmentMap.get(stageId);
dispatchablePlanFragment.setWorkerMetadataList(Arrays.asList(workerMetadataArray));
if (workerIdToSegmentsMap != null) {
dispatchablePlanFragment.setWorkerIdToSegmentsMap(workerIdToSegmentsMap);
@@ -130,14 +134,26 @@ public class DispatchablePlanContext {
dispatchablePlanFragment.setTimeBoundaryInfo(dispatchablePlanMetadata.getTimeBoundaryInfo());
}
}
- return Arrays.asList(dispatchablePlanFragmentArray);
+ return dispatchablePlanFragmentMap;
}
- private void createDispatchablePlanFragmentList(DispatchablePlanFragment[]
dispatchablePlanFragmentArray,
- PlanFragment planFragmentRoot) {
- dispatchablePlanFragmentArray[planFragmentRoot.getFragmentId()] = new
DispatchablePlanFragment(planFragmentRoot);
- for (PlanFragment childPlanFragment : planFragmentRoot.getChildren()) {
- createDispatchablePlanFragmentList(dispatchablePlanFragmentArray,
childPlanFragment);
+ private Map<Integer, DispatchablePlanFragment>
createDispatchablePlanFragmentMap(PlanFragment planFragmentRoot) {
+ HashMap<Integer, DispatchablePlanFragment> result =
+ Maps.newHashMapWithExpectedSize(_dispatchablePlanMetadataMap.size());
+ Queue<PlanFragment> pendingPlanFragmentIds = new ArrayDeque<>();
+ pendingPlanFragmentIds.add(planFragmentRoot);
+ while (!pendingPlanFragmentIds.isEmpty()) {
+ PlanFragment planFragment = pendingPlanFragmentIds.poll();
+ int planFragmentId = planFragment.getFragmentId();
+
+ if (result.containsKey(planFragmentId)) { // this can happen if some
stage is spooled.
+ LOGGER.debug("Skipping already visited stage {}", planFragmentId);
+ continue;
+ }
+ result.put(planFragmentId, new DispatchablePlanFragment(planFragment));
+
+ pendingPlanFragmentIds.addAll(planFragment.getChildren());
}
+ return result;
}
}
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
index 5299b08ce7..83458f6cc7 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java
@@ -18,9 +18,12 @@
*/
package org.apache.pinot.query.planner.physical;
+import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.SortedSet;
+import java.util.TreeSet;
import org.apache.calcite.runtime.PairList;
import org.apache.pinot.core.util.QueryMultiThreadingUtils;
@@ -40,24 +43,59 @@ import org.apache.pinot.core.util.QueryMultiThreadingUtils;
*/
public class DispatchableSubPlan {
private final PairList<Integer, String> _queryResultFields;
- private final List<DispatchablePlanFragment> _queryStageList;
+
+ /**
+ * Map from stage id to stage plan.
+ */
+ private final Map<Integer, DispatchablePlanFragment> _queryStageMap;
private final Set<String> _tableNames;
private final Map<String, Set<String>> _tableToUnavailableSegmentsMap;
- public DispatchableSubPlan(PairList<Integer, String> fields,
List<DispatchablePlanFragment> queryStageList,
+ public DispatchableSubPlan(PairList<Integer, String> fields,
+ Map<Integer, DispatchablePlanFragment> queryStageMap,
Set<String> tableNames, Map<String, Set<String>>
tableToUnavailableSegmentsMap) {
_queryResultFields = fields;
- _queryStageList = queryStageList;
+ _queryStageMap = queryStageMap;
_tableNames = tableNames;
_tableToUnavailableSegmentsMap = tableToUnavailableSegmentsMap;
}
/**
- * Get the list of stage plan root node.
+ * Get a map from stage id to stage plan.
* @return stage plan map.
*/
- public List<DispatchablePlanFragment> getQueryStageList() {
- return _queryStageList;
+ public Map<Integer, DispatchablePlanFragment> getQueryStageMap() {
+ return _queryStageMap;
+ }
+
+ private static Comparator<DispatchablePlanFragment> byStageIdComparator() {
+ return Comparator.comparing(d -> d.getPlanFragment().getFragmentId());
+ }
+
+ /**
+ * Get the query stages.
+ *
+ * The returned set is sorted by stage id.
+ */
+ public SortedSet<DispatchablePlanFragment> getQueryStages() {
+ TreeSet<DispatchablePlanFragment> treeSet = new
TreeSet<>(byStageIdComparator());
+ treeSet.addAll(_queryStageMap.values());
+ return treeSet;
+ }
+
+ /**
+ * Get the query stages without the root stage.
+ *
+ * The returned set is sorted by stage id.
+ */
+ public SortedSet<DispatchablePlanFragment> getQueryStagesWithoutRoot() {
+ SortedSet<DispatchablePlanFragment> result = getQueryStages();
+
+ DispatchablePlanFragment root = _queryStageMap.get(0);
+ if (root != null) {
+ result.remove(root);
+ }
+ return result;
}
/**
@@ -90,7 +128,7 @@ public class DispatchableSubPlan {
public int getEstimatedNumQueryThreads() {
int estimatedNumQueryThreads = 0;
// Skip broker reduce root stage
- for (DispatchablePlanFragment stage : _queryStageList.subList(1,
_queryStageList.size())) {
+ for (DispatchablePlanFragment stage : getQueryStagesWithoutRoot()) {
// Non-leaf stage
if (stage.getWorkerIdToSegmentsMap().isEmpty()) {
estimatedNumQueryThreads += stage.getWorkerMetadataList().size();
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
index 0828aa49ff..80ead2db2e 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
@@ -100,7 +100,7 @@ public class PinotDispatchPlanner {
private static DispatchableSubPlan finalizeDispatchableSubPlan(PlanFragment
subPlanRoot,
DispatchablePlanContext dispatchablePlanContext) {
return new DispatchableSubPlan(dispatchablePlanContext.getResultFields(),
-
dispatchablePlanContext.constructDispatchablePlanFragmentList(subPlanRoot),
+
dispatchablePlanContext.constructDispatchablePlanFragmentMap(subPlanRoot),
dispatchablePlanContext.getTableNames(),
populateTableUnavailableSegments(dispatchablePlanContext.getDispatchablePlanMetadataMap()));
}
diff --git
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index e490fb6b12..0a9e02f40c 100644
---
a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++
b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -24,6 +24,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
@@ -94,9 +95,9 @@ public class QueryCompilationTest extends
QueryEnvironmentTestBase {
}
private static void assertGroupBySingletonAfterJoin(DispatchableSubPlan
dispatchableSubPlan, boolean shouldRewrite) {
- for (int stageId = 0; stageId <
dispatchableSubPlan.getQueryStageList().size(); stageId++) {
+ for (int stageId = 0; stageId <
dispatchableSubPlan.getQueryStageMap().size(); stageId++) {
if (dispatchableSubPlan.getTableNames().size() == 0 &&
!PlannerUtils.isRootPlanFragment(stageId)) {
- PlanNode node =
dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment().getFragmentRoot();
+ PlanNode node =
dispatchableSubPlan.getQueryStageMap().get(stageId).getPlanFragment().getFragmentRoot();
while (node != null) {
if (node instanceof JoinNode) {
// JOIN is exchanged with hash distribution (data shuffle)
@@ -126,11 +127,11 @@ public class QueryCompilationTest extends
QueryEnvironmentTestBase {
public void testQueryAndAssertStageContentForJoin() {
String query = "SELECT * FROM a JOIN b ON a.col1 = b.col2";
DispatchableSubPlan dispatchableSubPlan =
_queryEnvironment.planQuery(query);
- List<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStageList();
+ Set<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStages();
int numStages = stagePlans.size();
assertEquals(numStages, 4);
- for (int stageId = 0; stageId < numStages; stageId++) {
- DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
+ for (DispatchablePlanFragment stagePlan : stagePlans) {
+ int stageId = stagePlan.getPlanFragment().getFragmentId();
Map<QueryServerInstance, List<Integer>> serverToWorkerIdsMap =
stagePlan.getServerInstanceToWorkerIdMap();
int numServers = serverToWorkerIdsMap.size();
String tableName = stagePlan.getTableName();
@@ -166,9 +167,9 @@ public class QueryCompilationTest extends
QueryEnvironmentTestBase {
String query = "SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON
a.col1 = b.col2 "
+ "WHERE a.col3 >= 0 AND a.col2 IN ('b') AND b.col3 < 0";
DispatchableSubPlan dispatchableSubPlan =
_queryEnvironment.planQuery(query);
- List<DispatchablePlanFragment> intermediateStages =
- dispatchableSubPlan.getQueryStageList().stream().filter(q ->
q.getTableName() == null)
- .collect(Collectors.toList());
+ List<DispatchablePlanFragment> intermediateStages =
dispatchableSubPlan.getQueryStageMap().values().stream()
+ .filter(q -> q.getTableName() == null)
+ .collect(Collectors.toList());
// Assert that no project of filter node for any intermediate stage
because all should've been pushed down.
for (DispatchablePlanFragment dispatchablePlanFragment :
intermediateStages) {
PlanNode roots =
dispatchablePlanFragment.getPlanFragment().getFragmentRoot();
@@ -180,25 +181,25 @@ public class QueryCompilationTest extends
QueryEnvironmentTestBase {
public void testQueryRoutingManagerCompilation() {
String query = "SELECT * FROM d_OFFLINE";
DispatchableSubPlan dispatchableSubPlan =
_queryEnvironment.planQuery(query);
- List<DispatchablePlanFragment> tableScanMetadataList =
- dispatchableSubPlan.getQueryStageList().stream().filter(stageMetadata
-> stageMetadata.getTableName() != null)
- .collect(Collectors.toList());
+ List<DispatchablePlanFragment> tableScanMetadataList =
dispatchableSubPlan.getQueryStageMap().values().stream()
+ .filter(stageMetadata -> stageMetadata.getTableName() != null)
+ .collect(Collectors.toList());
assertEquals(tableScanMetadataList.size(), 1);
assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(),
2);
query = "SELECT * FROM d_REALTIME";
dispatchableSubPlan = _queryEnvironment.planQuery(query);
- tableScanMetadataList =
- dispatchableSubPlan.getQueryStageList().stream().filter(stageMetadata
-> stageMetadata.getTableName() != null)
- .collect(Collectors.toList());
+ tableScanMetadataList =
dispatchableSubPlan.getQueryStageMap().values().stream()
+ .filter(stageMetadata -> stageMetadata.getTableName() != null)
+ .collect(Collectors.toList());
assertEquals(tableScanMetadataList.size(), 1);
assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(),
1);
query = "SELECT * FROM d";
dispatchableSubPlan = _queryEnvironment.planQuery(query);
- tableScanMetadataList =
- dispatchableSubPlan.getQueryStageList().stream().filter(stageMetadata
-> stageMetadata.getTableName() != null)
- .collect(Collectors.toList());
+ tableScanMetadataList =
dispatchableSubPlan.getQueryStageMap().values().stream()
+ .filter(stageMetadata -> stageMetadata.getTableName() != null)
+ .collect(Collectors.toList());
assertEquals(tableScanMetadataList.size(), 1);
assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(),
2);
}
@@ -260,11 +261,11 @@ public class QueryCompilationTest extends
QueryEnvironmentTestBase {
String query =
"SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */
col1, COUNT(*) FROM b GROUP BY col1";
DispatchableSubPlan dispatchableSubPlan =
_queryEnvironment.planQuery(query);
- List<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStageList();
+ Set<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStages();
int numStages = stagePlans.size();
assertEquals(numStages, 2);
- for (int stageId = 0; stageId < numStages; stageId++) {
- DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
+ for (DispatchablePlanFragment stagePlan : stagePlans) {
+ int stageId = stagePlan.getPlanFragment().getFragmentId();
Map<QueryServerInstance, List<Integer>> serverToWorkerIdsMap =
stagePlan.getServerInstanceToWorkerIdMap();
int numServers = serverToWorkerIdsMap.size();
String tableName = stagePlan.getTableName();
diff --git
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/PlanNodeSerDeTest.java
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/PlanNodeSerDeTest.java
index 1f1825fb0a..8fc02e7812 100644
---
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/PlanNodeSerDeTest.java
+++
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/PlanNodeSerDeTest.java
@@ -32,7 +32,7 @@ public class PlanNodeSerDeTest extends
QueryEnvironmentTestBase {
@Test(dataProvider = "testQueryDataProvider")
public void testQueryStagePlanSerDe(String query) {
DispatchableSubPlan dispatchableSubPlan =
_queryEnvironment.planQuery(query);
- for (DispatchablePlanFragment dispatchablePlanFragment :
dispatchableSubPlan.getQueryStageList()) {
+ for (DispatchablePlanFragment dispatchablePlanFragment :
dispatchableSubPlan.getQueryStages()) {
PlanNode stagePlan =
dispatchablePlanFragment.getPlanFragment().getFragmentRoot();
PlanNode deserializedStagePlan =
PlanNodeDeserializer.process(PlanNodeSerializer.process(stagePlan));
assertEquals(stagePlan, deserializedStagePlan);
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
index cd48518e2f..6bad667550 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/MultiStageStatsTreeBuilder.java
@@ -19,8 +19,9 @@
package org.apache.pinot.query.runtime;
import com.fasterxml.jackson.databind.node.ObjectNode;
-import java.util.ArrayList;
+import com.google.common.collect.Maps;
import java.util.List;
+import java.util.Map;
import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
@@ -28,16 +29,16 @@ import org.apache.pinot.spi.utils.JsonUtils;
public class MultiStageStatsTreeBuilder {
- private final List<PlanNode> _planNodes;
+ private final Map<Integer, PlanNode> _planNodes;
private final List<? extends MultiStageQueryStats.StageStats> _queryStats;
- private final List<DispatchablePlanFragment> _planFragments;
+ private final Map<Integer, DispatchablePlanFragment> _planFragments;
- public MultiStageStatsTreeBuilder(List<DispatchablePlanFragment>
planFragments,
+ public MultiStageStatsTreeBuilder(Map<Integer, DispatchablePlanFragment>
planFragments,
List<? extends MultiStageQueryStats.StageStats> queryStats) {
_planFragments = planFragments;
- _planNodes = new ArrayList<>(planFragments.size());
- for (DispatchablePlanFragment stagePlan : planFragments) {
- _planNodes.add(stagePlan.getPlanFragment().getFragmentRoot());
+ _planNodes = Maps.newHashMapWithExpectedSize(planFragments.size());
+ for (Map.Entry<Integer, DispatchablePlanFragment> entry :
planFragments.entrySet()) {
+ _planNodes.put(entry.getKey(),
entry.getValue().getPlanFragment().getFragmentRoot());
}
_queryStats = queryStats;
}
diff --git
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 38d617a255..bd596b116d 100644
---
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -20,6 +20,7 @@ package org.apache.pinot.query.service.dispatch;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.grpc.ConnectivityState;
@@ -44,6 +45,7 @@ import java.util.function.BiConsumer;
import java.util.function.Consumer;
import javax.annotation.Nullable;
import org.apache.calcite.runtime.PairList;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.exception.QueryException;
@@ -158,7 +160,7 @@ public class QueryDispatcher {
long requestId = context.getRequestId();
List<PlanNode> planNodes = new ArrayList<>();
- List<DispatchablePlanFragment> plans = Collections.singletonList(fragment);
+ Set<DispatchablePlanFragment> plans = Collections.singleton(fragment);
Set<QueryServerInstance> servers = new HashSet<>();
try {
SendRequest<List<Worker.ExplainResponse>> requestSender =
DispatchClient::explain;
@@ -199,8 +201,7 @@ public class QueryDispatcher {
Map<String, String> queryOptions)
throws Exception {
SendRequest<Worker.QueryResponse> requestSender = DispatchClient::submit;
- List<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStageList();
- List<DispatchablePlanFragment> plansWithoutRoot = stagePlans.subList(1,
stagePlans.size());
+ Set<DispatchablePlanFragment> plansWithoutRoot =
dispatchableSubPlan.getQueryStagesWithoutRoot();
execute(requestId, plansWithoutRoot, timeoutMs, queryOptions,
requestSender, serversOut,
(response, serverInstance) -> {
if
(response.containsMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR))
{
@@ -244,7 +245,7 @@ public class QueryDispatcher {
return _serversByQuery != null;
}
- private <E> void execute(long requestId, List<DispatchablePlanFragment>
stagePlans,
+ private <E> void execute(long requestId, Set<DispatchablePlanFragment>
stagePlans,
long timeoutMs, Map<String, String> queryOptions,
SendRequest<E> sendRequest, Set<QueryServerInstance> serverInstancesOut,
BiConsumer<E, QueryServerInstance> resultConsumer)
@@ -252,7 +253,8 @@ public class QueryDispatcher {
Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
- List<StageInfo> stageInfos = serializePlanFragments(stagePlans,
serverInstancesOut, deadline);
+ Map<DispatchablePlanFragment, StageInfo> stageInfos =
+ serializePlanFragments(stagePlans, serverInstancesOut, deadline);
if (serverInstancesOut.isEmpty()) {
throw new RuntimeException("No server instances to dispatch query to");
@@ -272,8 +274,7 @@ public class QueryDispatcher {
serverInstance);
}
};
- Worker.QueryRequest requestBuilder =
- createRequest(serverInstance, stagePlans, stageInfos,
protoRequestMetadata);
+ Worker.QueryRequest requestBuilder = createRequest(serverInstance,
stageInfos, protoRequestMetadata);
DispatchClient dispatchClient =
getOrCreateDispatchClient(serverInstance);
try {
@@ -333,11 +334,12 @@ public class QueryDispatcher {
}
private static Worker.QueryRequest createRequest(QueryServerInstance
serverInstance,
- List<DispatchablePlanFragment> stagePlans, List<StageInfo> stageInfos,
ByteString protoRequestMetadata) {
+ Map<DispatchablePlanFragment, StageInfo> stageInfos, ByteString
protoRequestMetadata) {
Worker.QueryRequest.Builder requestBuilder =
Worker.QueryRequest.newBuilder();
requestBuilder.setVersion(CommonConstants.MultiStageQueryRunner.PlanVersions.V1);
- for (int i = 0; i < stagePlans.size(); i++) {
- DispatchablePlanFragment stagePlan = stagePlans.get(i);
+
+ for (Map.Entry<DispatchablePlanFragment, StageInfo> entry :
stageInfos.entrySet()) {
+ DispatchablePlanFragment stagePlan = entry.getKey();
List<Integer> workerIds =
stagePlan.getServerInstanceToWorkerIdMap().get(serverInstance);
if (workerIds != null) { // otherwise this server doesn't need to
execute this stage
List<WorkerMetadata> stageWorkerMetadataList =
stagePlan.getWorkerMetadataList();
@@ -347,21 +349,18 @@ public class QueryDispatcher {
}
List<Worker.WorkerMetadata> protoWorkerMetadataList =
QueryPlanSerDeUtils.toProtoWorkerMetadataList(workerMetadataList);
- StageInfo stageInfo = stageInfos.get(i);
+ StageInfo stageInfo = entry.getValue();
- //@formatter:off
Worker.StagePlan requestStagePlan = Worker.StagePlan.newBuilder()
.setRootNode(stageInfo._rootNode)
.setStageMetadata(
Worker.StageMetadata.newBuilder()
- // this is a leak from submitAndReduce (id may be
different in explain), but it's fine for now
- .setStageId(i + 1)
+ .setStageId(stagePlan.getPlanFragment().getFragmentId())
.addAllWorkerMetadata(protoWorkerMetadataList)
.setCustomProperty(stageInfo._customProperty)
.build()
)
.build();
- //@formatter:on
requestBuilder.addStagePlan(requestStagePlan);
}
}
@@ -379,18 +378,22 @@ public class QueryDispatcher {
return requestMetadata;
}
- private List<StageInfo>
serializePlanFragments(List<DispatchablePlanFragment> stagePlans,
+ private Map<DispatchablePlanFragment, StageInfo> serializePlanFragments(
+ Set<DispatchablePlanFragment> stagePlans,
Set<QueryServerInstance> serverInstances, Deadline deadline)
throws InterruptedException, ExecutionException, TimeoutException {
- List<CompletableFuture<StageInfo>> stageInfoFutures = new
ArrayList<>(stagePlans.size());
+ List<CompletableFuture<Pair<DispatchablePlanFragment, StageInfo>>>
stageInfoFutures =
+ new ArrayList<>(stagePlans.size());
for (DispatchablePlanFragment stagePlan : stagePlans) {
serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
- stageInfoFutures.add(CompletableFuture.supplyAsync(() ->
serializePlanFragment(stagePlan), _executorService));
+ stageInfoFutures.add(
+ CompletableFuture.supplyAsync(() -> Pair.of(stagePlan,
serializePlanFragment(stagePlan)), _executorService));
}
- List<StageInfo> stageInfos = new ArrayList<>(stagePlans.size());
+ Map<DispatchablePlanFragment, StageInfo> stageInfos =
Maps.newHashMapWithExpectedSize(stagePlans.size());
try {
- for (CompletableFuture<StageInfo> future : stageInfoFutures) {
-
stageInfos.add(future.get(deadline.timeRemaining(TimeUnit.MILLISECONDS),
TimeUnit.MILLISECONDS));
+      for (CompletableFuture<Pair<DispatchablePlanFragment, StageInfo>> future
: stageInfoFutures) {
+        Pair<DispatchablePlanFragment, StageInfo> pair = future.get(deadline.timeRemaining(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS);
+        stageInfos.put(pair.getKey(), pair.getValue());
}
} finally {
for (CompletableFuture<?> future : stageInfoFutures) {
@@ -469,7 +472,7 @@ public class QueryDispatcher {
long startTimeMs = System.currentTimeMillis();
long deadlineMs = startTimeMs + timeoutMs;
// NOTE: Reduce stage is always stage 0
- DispatchablePlanFragment stagePlan = subPlan.getQueryStageList().get(0);
+ DispatchablePlanFragment stagePlan = subPlan.getQueryStageMap().get(0);
PlanFragment planFragment = stagePlan.getPlanFragment();
PlanNode rootNode = planFragment.getFragmentRoot();
diff --git
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTestBase.java
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTestBase.java
index a00b6bbcda..b40192e4e7 100644
---
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTestBase.java
+++
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTestBase.java
@@ -38,6 +38,7 @@ import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
@@ -131,13 +132,12 @@ public abstract class QueryRunnerTestBase extends
QueryTestSet {
}
// Submission Stub logic are mimic {@link QueryServer}
- List<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStageList();
+ Set<DispatchablePlanFragment> stagePlans =
dispatchableSubPlan.getQueryStagesWithoutRoot();
List<CompletableFuture<?>> submissionStubs = new ArrayList<>();
- for (int stageId = 0; stageId < stagePlans.size(); stageId++) {
- if (stageId != 0) {
-
submissionStubs.addAll(processDistributedStagePlans(dispatchableSubPlan,
requestId, stageId,
- requestMetadataMap));
- }
+ for (DispatchablePlanFragment stagePlan : stagePlans) {
+ int stageId = stagePlan.getPlanFragment().getFragmentId();
+ submissionStubs.addAll(processDistributedStagePlans(dispatchableSubPlan,
requestId, stageId,
+ requestMetadataMap));
}
try {
CompletableFuture.allOf(submissionStubs.toArray(new
CompletableFuture[0])).get(timeoutMs, TimeUnit.MILLISECONDS);
@@ -159,7 +159,7 @@ public abstract class QueryRunnerTestBase extends
QueryTestSet {
protected List<CompletableFuture<?>>
processDistributedStagePlans(DispatchableSubPlan dispatchableSubPlan,
long requestId, int stageId, Map<String, String> requestMetadataMap) {
- DispatchablePlanFragment dispatchableStagePlan =
dispatchableSubPlan.getQueryStageList().get(stageId);
+ DispatchablePlanFragment dispatchableStagePlan =
dispatchableSubPlan.getQueryStageMap().get(stageId);
List<WorkerMetadata> stageWorkerMetadataList =
dispatchableStagePlan.getWorkerMetadataList();
List<CompletableFuture<?>> submissionStubs = new ArrayList<>();
for (Map.Entry<QueryServerInstance, List<Integer>> entry :
dispatchableStagePlan.getServerInstanceToWorkerIdMap()
diff --git
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
index f94d85c92b..7d78bd6f94 100644
---
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
+++
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
@@ -287,7 +287,7 @@ public class ResourceBasedQueriesTest extends
QueryRunnerTestBase {
private Map<String, JsonNode> tableToStats(String sql,
QueryDispatcher.QueryResult queryResult) {
- List<DispatchablePlanFragment> planNodes =
planQuery(sql).getQueryPlan().getQueryStageList();
+ Map<Integer, DispatchablePlanFragment> planNodes =
planQuery(sql).getQueryPlan().getQueryStageMap();
MultiStageStatsTreeBuilder multiStageStatsTreeBuilder =
new MultiStageStatsTreeBuilder(planNodes, queryResult.getQueryStats());
diff --git
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
index 7a14a2a4c6..fed4fbd002 100644
---
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
+++
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
@@ -28,6 +28,7 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
+import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Plan;
@@ -119,10 +120,10 @@ public class QueryServerTest extends QueryTestSet {
public void testWorkerAcceptsWorkerRequestCorrect(String sql)
throws Exception {
DispatchableSubPlan queryPlan = _queryEnvironment.planQuery(sql);
- List<DispatchablePlanFragment> stagePlans = queryPlan.getQueryStageList();
- int numStages = stagePlans.size();
+ Set<DispatchablePlanFragment> stagePlans =
queryPlan.getQueryStagesWithoutRoot();
// Ignore reduce stage (stage 0)
- for (int stageId = 1; stageId < numStages; stageId++) {
+ for (DispatchablePlanFragment stagePlan : stagePlans) {
+ int stageId = stagePlan.getPlanFragment().getFragmentId();
// only get one worker request out.
Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, stageId);
Map<String, String> requestMetadata =
QueryPlanSerDeUtils.fromProtoProperties(queryRequest.getMetadata());
@@ -131,10 +132,8 @@ public class QueryServerTest extends QueryTestSet {
Worker.QueryResponse resp = submitRequest(queryRequest, requestMetadata);
assertTrue(resp.getMetadataMap().containsKey(CommonConstants.Query.Response.ServerResponseStatus.STATUS_OK));
- DispatchablePlanFragment dispatchableStagePlan = stagePlans.get(stageId);
- List<WorkerMetadata> workerMetadataList =
dispatchableStagePlan.getWorkerMetadataList();
- StageMetadata stageMetadata =
- new StageMetadata(stageId, workerMetadataList,
dispatchableStagePlan.getCustomProperties());
+ List<WorkerMetadata> workerMetadataList =
stagePlan.getWorkerMetadataList();
+ StageMetadata stageMetadata = new StageMetadata(stageId,
workerMetadataList, stagePlan.getCustomProperties());
// ensure mock query runner received correctly deserialized payload.
QueryRunner mockRunner =
_queryRunnerMap.get(Integer.parseInt(requestMetadata.get(KEY_OF_SERVER_INSTANCE_PORT)));
@@ -143,10 +142,10 @@ public class QueryServerTest extends QueryTestSet {
// since submitRequest is async, we need to wait for the mockRunner to
receive the query payload.
TestUtils.waitForCondition(aVoid -> {
try {
- verify(mockRunner,
times(workerMetadataList.size())).processQuery(any(), argThat(stagePlan -> {
- PlanNode planNode =
dispatchableStagePlan.getPlanFragment().getFragmentRoot();
- return planNode.equals(stagePlan.getRootNode()) &&
isStageMetadataEqual(stageMetadata,
- stagePlan.getStageMetadata());
+ verify(mockRunner,
times(workerMetadataList.size())).processQuery(any(), argThat(stagePlanArg -> {
+ PlanNode planNode = stagePlan.getPlanFragment().getFragmentRoot();
+ return planNode.equals(stagePlanArg.getRootNode()) &&
isStageMetadataEqual(stageMetadata,
+ stagePlanArg.getStageMetadata());
}), argThat(requestMetadataMap -> requestId.equals(
requestMetadataMap.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID))),
any());
return true;
@@ -206,7 +205,7 @@ public class QueryServerTest extends QueryTestSet {
}
private Worker.QueryRequest getQueryRequest(DispatchableSubPlan queryPlan,
int stageId) {
- DispatchablePlanFragment stagePlan =
queryPlan.getQueryStageList().get(stageId);
+ DispatchablePlanFragment stagePlan =
queryPlan.getQueryStageMap().get(stageId);
Plan.PlanNode rootNode =
PlanNodeSerializer.process(stagePlan.getPlanFragment().getFragmentRoot());
List<Worker.WorkerMetadata> workerMetadataList =
QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan.getWorkerMetadataList());
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]