This is an automated email from the ASF dual-hosted git repository.

yashmayya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ce48fb2e57 Fix non-deterministic hash-distributed exchange routing in multi-stage query engine (#17323)
1ce48fb2e57 is described below

commit 1ce48fb2e574b748ec18d5a99fbe02ac4267145b
Author: Yash Mayya <[email protected]>
AuthorDate: Mon Dec 8 15:03:35 2025 -0800

    Fix non-deterministic hash-distributed exchange routing in multi-stage query engine (#17323)
---
 .../planner/physical/MailboxAssignmentVisitor.java |   7 ++
 .../apache/pinot/query/routing/WorkerManager.java  |   1 +
 .../physical/MailboxAssignmentVisitorTest.java     |  99 ++++++++++++++++++
 .../resources/queries/ExplainPhysicalPlans.json    |  60 +++++++++++
 .../src/test/resources/queries/Joins.json          | 112 +++++++++++++++++++++
 5 files changed, 279 insertions(+)

diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitor.java
index 758b72f2222..233874017ae 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitor.java
@@ -20,6 +20,7 @@ package org.apache.pinot.query.planner.physical;
 
 import com.google.common.base.Preconditions;
 import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -187,6 +188,12 @@ public class MailboxAssignmentVisitor extends DefaultPostOrderTraversalVisitor<V
       List<Integer> workerIds = entry.getValue();
       mailboxInfoList.add(new MailboxInfo(server.getHostname(), server.getQueryMailboxPort(), workerIds));
     }
+    // Sort by first workerId to ensure deterministic ordering.
+    // This is critical for hash-distributed exchanges where (hash % numMailboxes) is used as an index to choose the
+    // receiving worker.
+    // Without this sorting, different stages could route the same hash value to different workers, resulting in
+    // incorrect join/union/intersect results for pre-partitioned sends (where a full partition shuffle is skipped).
+    mailboxInfoList.sort(Comparator.comparingInt(info -> info.getWorkerIds().get(0)));
     MailboxInfos mailboxInfos =
         numWorkers > 1 ? new SharedMailboxInfos(mailboxInfoList) : new MailboxInfos(mailboxInfoList);
     for (int workerId = 0; workerId < numWorkers; workerId++) {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
index b4edb4414c8..3d379896d75 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
@@ -114,6 +114,7 @@ public class WorkerManager {
     }
   }
 
+  // TODO: Ensure that workerId to server assignment is deterministic across all stages in a query.
   private void assignWorkersToNonRootFragment(PlanFragment fragment, DispatchablePlanContext context) {
     List<PlanFragment> children = fragment.getChildren();
     for (PlanFragment child : children) {
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitorTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitorTest.java
new file mode 100644
index 00000000000..9df5eae9cac
--- /dev/null
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/MailboxAssignmentVisitorTest.java
@@ -0,0 +1,99 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.query.QueryEnvironmentTestBase;
+import org.apache.pinot.query.routing.MailboxInfo;
+import org.apache.pinot.query.routing.MailboxInfos;
+import org.apache.pinot.query.routing.WorkerMetadata;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Tests for mailbox assignment determinism in {@link MailboxAssignmentVisitor}.
+ *
+ * These tests verify that the mailbox info list is sorted by worker ID to ensure
+ * deterministic hash-distributed exchange routing. This is critical for correct
+ * join results when HashExchange uses (hash % numMailboxes) as an index.
+ */
+public class MailboxAssignmentVisitorTest extends QueryEnvironmentTestBase {
+
+  @Test
+  public void testVariousJoinQueriesHaveSortedMailboxes() {
+    String[] queries = {
+        // Simple join
+        "SELECT * FROM a JOIN b ON a.col1 = b.col1",
+        // Join with aggregation
+        "SELECT a.col1, COUNT(*) FROM a JOIN b ON a.col1 = b.col1 GROUP BY a.col1",
+        "SELECT a.col1, SUM(a.col3), SUM(b.col3) FROM a JOIN b ON a.col1 = b.col1 GROUP BY a.col1",
+        // Multi-way join
+        "SELECT * FROM a JOIN b ON a.col1 = b.col1 JOIN c ON b.col1 = c.col1",
+        // Join with filter
+        "SELECT * FROM a JOIN b ON a.col1 = b.col1 WHERE a.col3 > 0",
+    };
+
+    for (String query : queries) {
+      DispatchableSubPlan subPlan = _queryEnvironment.planQuery(query);
+      verifyAllMailboxInfosSorted(subPlan, query);
+    }
+  }
+
+  @Test
+  public void testUnionQueryHasSortedMailboxes() {
+    String query = "SELECT col1, SUM(col3) FROM a GROUP BY col1 "
+        + "UNION ALL "
+        + "SELECT col1, SUM(col3) FROM b GROUP BY col1";
+
+    DispatchableSubPlan subPlan = _queryEnvironment.planQuery(query);
+    verifyAllMailboxInfosSorted(subPlan, query);
+  }
+
+  private void verifyAllMailboxInfosSorted(DispatchableSubPlan subPlan, String query) {
+    for (DispatchablePlanFragment fragment : subPlan.getQueryStages()) {
+      List<WorkerMetadata> workerMetadataList = fragment.getWorkerMetadataList();
+
+      for (WorkerMetadata workerMetadata : workerMetadataList) {
+        Map<Integer, MailboxInfos> mailboxInfosMap = workerMetadata.getMailboxInfosMap();
+
+        for (Map.Entry<Integer, MailboxInfos> entry : mailboxInfosMap.entrySet()) {
+          MailboxInfos mailboxInfos = entry.getValue();
+          List<MailboxInfo> infoList = mailboxInfos.getMailboxInfos();
+
+          // Expand all worker IDs from all MailboxInfos
+          List<Integer> expandedWorkerIds = new ArrayList<>();
+          for (MailboxInfo info : infoList) {
+            expandedWorkerIds.addAll(info.getWorkerIds());
+          }
+
+          // Verify the expanded list is sorted
+          for (int i = 0; i < expandedWorkerIds.size() - 1; i++) {
+            assertTrue(expandedWorkerIds.get(i) < expandedWorkerIds.get(i + 1),
+                String.format("Expanded worker IDs not sorted: %d at index %d, %d at index %d",
+                    expandedWorkerIds.get(i), i, expandedWorkerIds.get(i + 1), i + 1));
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/pinot-query-planner/src/test/resources/queries/ExplainPhysicalPlans.json b/pinot-query-planner/src/test/resources/queries/ExplainPhysicalPlans.json
index 6e6227e1783..382196cb967 100644
--- a/pinot-query-planner/src/test/resources/queries/ExplainPhysicalPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/ExplainPhysicalPlans.json
@@ -647,6 +647,66 @@
           "                        └── [5]@localhost:2|[0] FILTER\n",
           "                            └── [5]@localhost:2|[0] TABLE SCAN (a) null\n"
         ]
+      },
+      {
+        "description": "explain plan with join after group-by on both tables - verifies pre-partitioned distribution when group keys match join keys",
+        "comments": "This test verifies that when both sides of a join are aggregated by the join key, the planner correctly marks the exchanges as [PARTITIONED].",
+        "sql": "EXPLAIN IMPLEMENTATION PLAN FOR SELECT tmpA.col1, tmpA.cnt, tmpB.total FROM (SELECT a.col1, COUNT(*) AS cnt FROM a GROUP BY a.col1) tmpA JOIN (SELECT b.col1, SUM(b.col3) AS total FROM b GROUP BY b.col1) tmpB ON tmpA.col1 = tmpB.col1",
+        "output": [
+          "[0]@localhost:3|[0] MAIL_RECEIVE(BROADCAST_DISTRIBUTED)\n",
+          "├── [1]@localhost:1|[1] MAIL_SEND(BROADCAST_DISTRIBUTED)->{[0]@localhost:3|[0]} (Subtree Omitted)\n",
+          "└── [1]@localhost:2|[0] MAIL_SEND(BROADCAST_DISTRIBUTED)->{[0]@localhost:3|[0]}\n",
+          "    └── [1]@localhost:2|[0] PROJECT\n",
+          "        └── [1]@localhost:2|[0] JOIN\n",
+          "            ├── [1]@localhost:2|[0] MAIL_RECEIVE(HASH_DISTRIBUTED)\n",
+          "            │   ├── [2]@localhost:1|[1] MAIL_SEND(HASH_DISTRIBUTED)[PARTITIONED]->{[1]@localhost:1|[1]} (Subtree Omitted)\n",
+          "            │   └── [2]@localhost:2|[0] MAIL_SEND(HASH_DISTRIBUTED)[PARTITIONED]->{[1]@localhost:2|[0]}\n",
+          "            │       └── [2]@localhost:2|[0] AGGREGATE_FINAL\n",
+          "            │           └── [2]@localhost:2|[0] MAIL_RECEIVE(HASH_DISTRIBUTED)\n",
+          "            │               ├── [3]@localhost:1|[1] MAIL_SEND(HASH_DISTRIBUTED)->{[2]@localhost:1|[1],[2]@localhost:2|[0]} (Subtree Omitted)\n",
+          "            │               └── [3]@localhost:2|[0] MAIL_SEND(HASH_DISTRIBUTED)->{[2]@localhost:1|[1],[2]@localhost:2|[0]}\n",
+          "            │                   └── [3]@localhost:2|[0] AGGREGATE_LEAF\n",
+          "            │                       └── [3]@localhost:2|[0] TABLE SCAN (a) null\n",
+          "            └── [1]@localhost:2|[0] MAIL_RECEIVE(HASH_DISTRIBUTED)\n",
+          "                ├── [4]@localhost:1|[1] MAIL_SEND(HASH_DISTRIBUTED)[PARTITIONED]->{[1]@localhost:1|[1]} (Subtree Omitted)\n",
+          "                └── [4]@localhost:2|[0] MAIL_SEND(HASH_DISTRIBUTED)[PARTITIONED]->{[1]@localhost:2|[0]}\n",
+          "                    └── [4]@localhost:2|[0] AGGREGATE_FINAL\n",
+          "                        └── [4]@localhost:2|[0] MAIL_RECEIVE(HASH_DISTRIBUTED)\n",
+          "                            └── [5]@localhost:1|[0] MAIL_SEND(HASH_DISTRIBUTED)->{[4]@localhost:1|[1],[4]@localhost:2|[0]}\n",
+          "                                └── [5]@localhost:1|[0] AGGREGATE_LEAF\n",
+          "                                    └── [5]@localhost:1|[0] TABLE SCAN (b) null\n",
+          ""
+        ]
+      },
+      {
+        "description": "explain plan with join after group-by - both sides partitioned when group keys match respective join keys",
+        "comments": "Both sides are [PARTITIONED] because left groups by col2 and joins on col2, right groups by col1 and joins on col1",
+        "sql": "EXPLAIN IMPLEMENTATION PLAN FOR SELECT tmpA.col2, tmpA.cnt, tmpB.total FROM (SELECT a.col2, COUNT(*) AS cnt FROM a GROUP BY a.col2) tmpA JOIN (SELECT b.col1, SUM(b.col3) AS total FROM b GROUP BY b.col1) tmpB ON tmpA.col2 = tmpB.col1",
+        "output": [
+          "[0]@localhost:3|[0] MAIL_RECEIVE(BROADCAST_DISTRIBUTED)\n",
+          "├── [1]@localhost:1|[1] MAIL_SEND(BROADCAST_DISTRIBUTED)->{[0]@localhost:3|[0]} (Subtree Omitted)\n",
+          "└── [1]@localhost:2|[0] MAIL_SEND(BROADCAST_DISTRIBUTED)->{[0]@localhost:3|[0]}\n",
+          "    └── [1]@localhost:2|[0] PROJECT\n",
+          "        └── [1]@localhost:2|[0] JOIN\n",
+          "            ├── [1]@localhost:2|[0] MAIL_RECEIVE(HASH_DISTRIBUTED)\n",
+          "            │   ├── [2]@localhost:1|[1] MAIL_SEND(HASH_DISTRIBUTED)[PARTITIONED]->{[1]@localhost:1|[1]} (Subtree Omitted)\n",
+          "            │   └── [2]@localhost:2|[0] MAIL_SEND(HASH_DISTRIBUTED)[PARTITIONED]->{[1]@localhost:2|[0]}\n",
+          "            │       └── [2]@localhost:2|[0] AGGREGATE_FINAL\n",
+          "            │           └── [2]@localhost:2|[0] MAIL_RECEIVE(HASH_DISTRIBUTED)\n",
+          "            │               ├── [3]@localhost:1|[1] MAIL_SEND(HASH_DISTRIBUTED)->{[2]@localhost:1|[1],[2]@localhost:2|[0]} (Subtree Omitted)\n",
+          "            │               └── [3]@localhost:2|[0] MAIL_SEND(HASH_DISTRIBUTED)->{[2]@localhost:1|[1],[2]@localhost:2|[0]}\n",
+          "            │                   └── [3]@localhost:2|[0] AGGREGATE_LEAF\n",
+          "            │                       └── [3]@localhost:2|[0] TABLE SCAN (a) null\n",
+          "            └── [1]@localhost:2|[0] MAIL_RECEIVE(HASH_DISTRIBUTED)\n",
+          "                ├── [4]@localhost:1|[1] MAIL_SEND(HASH_DISTRIBUTED)[PARTITIONED]->{[1]@localhost:1|[1]} (Subtree Omitted)\n",
+          "                └── [4]@localhost:2|[0] MAIL_SEND(HASH_DISTRIBUTED)[PARTITIONED]->{[1]@localhost:2|[0]}\n",
+          "                    └── [4]@localhost:2|[0] AGGREGATE_FINAL\n",
+          "                        └── [4]@localhost:2|[0] MAIL_RECEIVE(HASH_DISTRIBUTED)\n",
+          "                            └── [5]@localhost:1|[0] MAIL_SEND(HASH_DISTRIBUTED)->{[4]@localhost:1|[1],[4]@localhost:2|[0]}\n",
+          "                                └── [5]@localhost:1|[0] AGGREGATE_LEAF\n",
+          "                                    └── [5]@localhost:1|[0] TABLE SCAN (b) null\n",
+          ""
+        ]
       }
     ]
   }
diff --git a/pinot-query-runtime/src/test/resources/queries/Joins.json b/pinot-query-runtime/src/test/resources/queries/Joins.json
new file mode 100644
index 00000000000..50ecf1cc976
--- /dev/null
+++ b/pinot-query-runtime/src/test/resources/queries/Joins.json
@@ -0,0 +1,112 @@
+{
+  "hash_distributed_join_with_aggregates": {
+    "comment": "Tests join correctness when both inputs are already hash-distributed on the join keys via aggregation thus resulting in a pre-partitioned distribution.",
+    "tables": {
+      "left_tbl": {
+        "schema": [
+          {"name": "key_col", "type": "STRING"},
+          {"name": "value_col", "type": "INT"}
+        ],
+        "inputs": [
+          ["a", 1],
+          ["a", 2],
+          ["b", 3],
+          ["b", 4],
+          ["c", 5],
+          ["c", 6],
+          ["d", 7],
+          ["d", 8],
+          ["e", 9],
+          ["e", 10]
+        ]
+      },
+      "right_tbl": {
+        "schema": [
+          {"name": "key_col", "type": "STRING"},
+          {"name": "metric_col", "type": "INT"}
+        ],
+        "inputs": [
+          ["a", 100],
+          ["a", 200],
+          ["b", 300],
+          ["b", 400],
+          ["c", 500],
+          ["c", 600],
+          ["d", 700],
+          ["d", 800],
+          ["e", 900],
+          ["e", 1000]
+        ]
+      }
+    },
+    "queries": [
+      {
+        "description": "Join with aggregation on both sides - exercises hash-distributed exchange routing",
+        "sql": "SELECT l.key_col, l.sum_val, r.sum_metric FROM (SELECT key_col, SUM(value_col) AS sum_val FROM {left_tbl} GROUP BY key_col) l JOIN (SELECT key_col, SUM(metric_col) AS sum_metric FROM {right_tbl} GROUP BY key_col) r ON l.key_col = r.key_col ORDER BY l.key_col"
+      },
+      {
+        "description": "Count joined rows with aggregated inputs",
+        "sql": "SELECT COUNT(*) FROM (SELECT key_col, SUM(value_col) AS sum_val FROM {left_tbl} GROUP BY key_col) l JOIN (SELECT key_col, SUM(metric_col) AS sum_metric FROM {right_tbl} GROUP BY key_col) r ON l.key_col = r.key_col"
+      },
+      {
+        "description": "Sum of joined values",
+        "sql": "SELECT SUM(l.sum_val), SUM(r.sum_metric) FROM (SELECT key_col, SUM(value_col) AS sum_val FROM {left_tbl} GROUP BY key_col) l JOIN (SELECT key_col, SUM(metric_col) AS sum_metric FROM {right_tbl} GROUP BY key_col) r ON l.key_col = r.key_col"
+      },
+      {
+        "description": "Direct join with same group keys as join keys",
+        "sql": "SELECT a.key_col, SUM(a.value_col), SUM(b.metric_col) FROM {left_tbl} a JOIN {right_tbl} b ON a.key_col = b.key_col GROUP BY a.key_col ORDER BY a.key_col"
+      }
+    ]
+  },
+  "multi_table_join_chain": {
+    "comment": "Tests join chains that create multiple hash-distributed exchanges",
+    "tables": {
+      "orders": {
+        "schema": [
+          {"name": "order_id", "type": "INT"},
+          {"name": "customer_id", "type": "STRING"},
+          {"name": "amount", "type": "INT"}
+        ],
+        "inputs": [
+          [1, "c1", 100],
+          [2, "c1", 200],
+          [3, "c2", 150],
+          [4, "c2", 250],
+          [5, "c3", 300]
+        ]
+      },
+      "customers": {
+        "schema": [
+          {"name": "customer_id", "type": "STRING"},
+          {"name": "name", "type": "STRING"}
+        ],
+        "inputs": [
+          ["c1", "Alice"],
+          ["c2", "Bob"],
+          ["c3", "Charlie"]
+        ]
+      },
+      "regions": {
+        "schema": [
+          {"name": "customer_id", "type": "STRING"},
+          {"name": "region", "type": "STRING"}
+        ],
+        "inputs": [
+          ["c1", "North"],
+          ["c2", "South"],
+          ["c3", "East"]
+        ]
+      }
+    },
+    "queries": [
+      {
+        "description": "Three-way join with aggregation",
+        "sql": "SELECT c.name, r.region, SUM(o.amount) as total FROM {orders} o JOIN {customers} c ON o.customer_id = c.customer_id JOIN {regions} r ON c.customer_id = r.customer_id GROUP BY c.name, r.region ORDER BY c.name"
+      },
+      {
+        "description": "Join with pre-aggregated subquery",
+        "sql": "SELECT c.name, o.total_amount FROM {customers} c JOIN (SELECT customer_id, SUM(amount) as total_amount FROM {orders} GROUP BY customer_id) o ON c.customer_id = o.customer_id ORDER BY c.name"
+      }
+    ]
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to