This is an automated email from the ASF dual-hosted git repository.

gortiz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new cf720e3c2d [Spool] Introduce stage replacer and change send nodes to 
be able to send to more than one stage (#14495)
cf720e3c2d is described below

commit cf720e3c2d5f3c3bfed132724f5c5f57292c2c2d
Author: Gonzalo Ortiz Jaureguizar <[email protected]>
AuthorDate: Thu Nov 21 17:34:02 2024 +0100

    [Spool] Introduce stage replacer and change send nodes to be able to send 
to more than one stage (#14495)
---
 .../planner/logical/EquivalentStagesFinder.java    |  16 +-
 .../planner/logical/EquivalentStagesReplacer.java  |  79 ++++++
 .../pinot/query/planner/logical/GroupedStages.java |   8 +
 .../query/planner/plannode/MailboxReceiveNode.java |  11 +-
 .../query/planner/plannode/MailboxSendNode.java    |  94 +++++++-
 .../query/planner/plannode/PlanNodeVisitor.java    |  54 +++--
 .../logical/EquivalentStagesFinderTest.java        |  47 +++-
 .../logical/EquivalentStagesReplacerTest.java      | 146 ++++++++++++
 .../query/planner/logical/StagesTestBase.java      | 265 ++++++++++++++++++++-
 9 files changed, 671 insertions(+), 49 deletions(-)

diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java
index a5c98eb54c..28bca306cd 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java
@@ -120,11 +120,15 @@ public class EquivalentStagesFinder {
         return areBaseNodesEquivalent(stage, visitedStage)
             // Commented out fields are used in equals() method of 
MailboxSendNode but not needed for equivalence.
             // Receiver stage is not important for equivalence
-//            && node1.getReceiverStageId() == that.getReceiverStageId()
+//            && stage.getReceiverStageId() == 
visitedStage.getReceiverStageId()
             && stage.getExchangeType() == visitedStage.getExchangeType()
-            // Distribution type is not needed for equivalence. We deal with 
difference distribution types in the
-            // spooling logic.
-//            && Objects.equals(node1.getDistributionType(), 
that.getDistributionType())
+            // TODO: Distribution type not needed for equivalence in the first 
substituted send nodes. Their different
+            //  distribution can be implemented in synthetic stages. But it is 
important in recursive send nodes
+            //  (a send node that is equivalent to another but where both of 
them send to stages that are also
+            //  equivalent).
+            //  This makes the equivalence check more complex and therefore we 
are going to consider the distribution
+            //  type in the equivalence check.
+            && Objects.equals(stage.getDistributionType(), 
visitedStage.getDistributionType())
             // TODO: Keys could probably be removed from the equivalence 
check, but would require to verify both
             //  keys are present in the data schema. We are not doing that for 
now.
             && Objects.equals(stage.getKeys(), visitedStage.getKeys())
@@ -220,9 +224,7 @@ public class EquivalentStagesFinder {
             // TODO: Keys should probably be removed from the equivalence 
check, but would require to verify both
             //  keys are present in the data schema. We are not doing that for 
now.
             && Objects.equals(node1.getKeys(), that.getKeys())
-            // Distribution type is not needed for equivalence. We deal with 
difference distribution types in the
-            // spooling logic.
-//          && node1.getDistributionType() == that.getDistributionType()
+            && node1.getDistributionType() == that.getDistributionType()
             // TODO: Sort, sort on sender and collations can probably be 
removed from the equivalence check, but would
             //  require some extra checks or transformation on the spooling 
logic. We are not doing that for now.
             && node1.isSort() == that.isSort()
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacer.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacer.java
new file mode 100644
index 0000000000..06a4cf16da
--- /dev/null
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacer.java
@@ -0,0 +1,79 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.logical;
+
+import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
+import org.apache.pinot.query.planner.plannode.MailboxSendNode;
+import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.planner.plannode.PlanNodeVisitor;
+
+
+/**
+ * EquivalentStagesReplacer is used to replace equivalent stages in the query 
plan.
+ *
+ * Given a {@link org.apache.pinot.query.planner.plannode.PlanNode} and a
+ * {@link GroupedStages}, modifies the plan node to replace equivalent stages.
+ *
+ * For each {@link MailboxReceiveNode} in the plan, if the sender is not the 
leader of the group,
+ * replaces the sender with the leader.
+ * The leader is also updated to include the receiver in its list of receivers.
+ */
+public class EquivalentStagesReplacer {
+  private EquivalentStagesReplacer() {
+  }
+
+  /**
+   * Replaces the equivalent stages in the query plan.
+   *
+   * @param root Root plan node
+   * @param equivalentStages Equivalent stages
+   */
+  public static void replaceEquivalentStages(PlanNode root, GroupedStages 
equivalentStages) {
+    root.visit(Replacer.INSTANCE, equivalentStages);
+  }
+
+  private static class Replacer extends 
PlanNodeVisitor.DepthFirstVisitor<Void, GroupedStages> {
+    private static final Replacer INSTANCE = new Replacer();
+
+    private Replacer() {
+    }
+
+    @Override
+    public Void visitMailboxReceive(MailboxReceiveNode node, GroupedStages 
equivalenceGroups) {
+      MailboxSendNode sender = node.getSender();
+      MailboxSendNode leader = equivalenceGroups.getGroup(sender).first();
+      if (canSubstitute(sender, leader)) {
+        // we don't want to visit the children of the node given it is going 
to be pruned
+        node.setSender(leader);
+        leader.addReceiver(node);
+      } else {
+        visitMailboxSend(leader, equivalenceGroups);
+      }
+      return null;
+    }
+
+    private boolean canSubstitute(MailboxSendNode actualSender, 
MailboxSendNode leader) {
+      return actualSender != leader // we don't need to replace the leader 
with itself
+          // the leader is already sending to this stage. Given we don't have 
the ability to send to multiple
+          // receivers in the same stage, we cannot optimize this case right 
now.
+          // If this case seems to be useful, it can be supported in the 
future.
+          && !leader.sharesReceiverStages(actualSender);
+    }
+  }
+}
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/GroupedStages.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/GroupedStages.java
index 45b5b561f9..823a9e9832 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/GroupedStages.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/GroupedStages.java
@@ -22,6 +22,7 @@ import com.google.common.base.Preconditions;
 import java.util.Comparator;
 import java.util.IdentityHashMap;
 import java.util.NoSuchElementException;
+import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
 import java.util.stream.Collectors;
@@ -77,6 +78,8 @@ public abstract class GroupedStages {
    */
   public abstract SortedSet<SortedSet<MailboxSendNode>> getGroups();
 
+  public abstract Set<MailboxSendNode> getStages();
+
   @Override
   public String toString() {
     String content = getGroups().stream()
@@ -154,6 +157,11 @@ public abstract class GroupedStages {
       return _stageToGroup.containsKey(stage);
     }
 
+    @Override
+    public Set<MailboxSendNode> getStages() {
+      return _stageToGroup.keySet();
+    }
+
     @Override
     public SortedSet<MailboxSendNode> getGroup(MailboxSendNode stage)
         throws NoSuchElementException {
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxReceiveNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxReceiveNode.java
index c918e9ea91..407941e6b4 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxReceiveNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxReceiveNode.java
@@ -29,7 +29,7 @@ import org.apache.pinot.common.utils.DataSchema;
 
 
 public class MailboxReceiveNode extends BasePlanNode {
-  private final int _senderStageId;
+  private int _senderStageId;
   private final PinotRelExchangeType _exchangeType;
   private RelDistribution.Type _distributionType;
   private final List<Integer> _keys;
@@ -38,7 +38,7 @@ public class MailboxReceiveNode extends BasePlanNode {
   private final boolean _sortedOnSender;
 
   // NOTE: This is only available during query planning, and should not be 
serialized.
-  private final transient MailboxSendNode _sender;
+  private transient MailboxSendNode _sender;
 
   // NOTE: null List is converted to empty List because there is no way to 
differentiate them in proto during ser/de.
   public MailboxReceiveNode(int stageId, DataSchema dataSchema, int 
senderStageId,
@@ -57,6 +57,8 @@ public class MailboxReceiveNode extends BasePlanNode {
   }
 
   public int getSenderStageId() {
+    assert _sender == null || _sender.getStageId() == _senderStageId
+        : "_senderStageId should match _sender.getStageId()";
     return _senderStageId;
   }
 
@@ -93,6 +95,11 @@ public class MailboxReceiveNode extends BasePlanNode {
     return _sender;
   }
 
+  public void setSender(MailboxSendNode sender) {
+    _senderStageId = sender.getStageId();
+    _sender = sender;
+  }
+
   @Override
   public String explain() {
     return "MAIL_RECEIVE(" + _distributionType + ")";
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxSendNode.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxSendNode.java
index b4aa8677e2..9cc2c2e657 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxSendNode.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxSendNode.java
@@ -18,6 +18,9 @@
  */
 package org.apache.pinot.query.planner.plannode;
 
+import com.google.common.base.Preconditions;
+import java.util.BitSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Objects;
 import javax.annotation.Nullable;
@@ -28,7 +31,7 @@ import org.apache.pinot.common.utils.DataSchema;
 
 
 public class MailboxSendNode extends BasePlanNode {
-  private final int _receiverStageId;
+  private final BitSet _receiverStages;
   private final PinotRelExchangeType _exchangeType;
   private RelDistribution.Type _distributionType;
   private final List<Integer> _keys;
@@ -37,11 +40,12 @@ public class MailboxSendNode extends BasePlanNode {
   private final boolean _sort;
 
   // NOTE: null List is converted to empty List because there is no way to 
differentiate them in proto during ser/de.
-  public MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> 
inputs, int receiverStageId,
-      PinotRelExchangeType exchangeType, RelDistribution.Type 
distributionType, @Nullable List<Integer> keys,
-      boolean prePartitioned, @Nullable List<RelFieldCollation> collations, 
boolean sort) {
+  private MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> 
inputs,
+      BitSet receiverStages, PinotRelExchangeType exchangeType,
+      RelDistribution.Type distributionType, @Nullable List<Integer> keys, 
boolean prePartitioned,
+      @Nullable List<RelFieldCollation> collations, boolean sort) {
     super(stageId, dataSchema, null, inputs);
-    _receiverStageId = receiverStageId;
+    _receiverStages = receiverStages;
     _exchangeType = exchangeType;
     _distributionType = distributionType;
     _keys = keys != null ? keys : List.of();
@@ -50,8 +54,74 @@ public class MailboxSendNode extends BasePlanNode {
     _sort = sort;
   }
 
+  public MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> 
inputs,
+      int receiverStage, PinotRelExchangeType exchangeType,
+      RelDistribution.Type distributionType, @Nullable List<Integer> keys, 
boolean prePartitioned,
+      @Nullable List<RelFieldCollation> collations, boolean sort) {
+    this(stageId, dataSchema, inputs, toBitSet(receiverStage), exchangeType, 
distributionType, keys, prePartitioned,
+        collations, sort);
+  }
+
+  private static BitSet toBitSet(int receiverStage) {
+    BitSet bitSet = new BitSet(receiverStage + 1);
+    bitSet.set(receiverStage);
+    return bitSet;
+  }
+
+  private static BitSet toBitSet(@Nullable List<Integer> receiverStages) {
+    BitSet bitSet = new BitSet();
+    if (receiverStages == null || receiverStages.isEmpty()) {
+      return bitSet;
+    }
+    for (int receiverStage : receiverStages) {
+      bitSet.set(receiverStage);
+    }
+    return bitSet;
+  }
+
+  public MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> 
inputs,
+      PinotRelExchangeType exchangeType, RelDistribution.Type 
distributionType, @Nullable List<Integer> keys,
+      boolean prePartitioned, @Nullable List<RelFieldCollation> collations, 
boolean sort) {
+    this(stageId, dataSchema, inputs, new BitSet(), exchangeType, 
distributionType, keys, prePartitioned, collations,
+        sort);
+  }
+
+  public boolean sharesReceiverStages(MailboxSendNode other) {
+    return _receiverStages.intersects(other._receiverStages);
+  }
+
+  /**
+   * Returns the receiver stage ids, sorted in ascending order.
+   */
+  public Iterable<Integer> getReceiverStageIds() {
+    return () -> new Iterator<>() {
+      int _next = _receiverStages.nextSetBit(0);
+
+      @Override
+      public boolean hasNext() {
+        return _next >= 0;
+      }
+
+      @Override
+      public Integer next() {
+        int current = _next;
+        _next = _receiverStages.nextSetBit(_next + 1);
+        return current;
+      }
+    };
+  }
+
+  @Deprecated
   public int getReceiverStageId() {
-    return _receiverStageId;
+    Preconditions.checkState(!_receiverStages.isEmpty(), "Receivers not set");
+    return _receiverStages.nextSetBit(0);
+  }
+
+  public void addReceiver(MailboxReceiveNode node) {
+    if (_receiverStages.get(node.getStageId())) {
+      throw new IllegalStateException("Receiver already added: " + 
node.getStageId());
+    }
+    _receiverStages.set(node.getStageId());
   }
 
   public PinotRelExchangeType getExchangeType() {
@@ -104,7 +174,7 @@ public class MailboxSendNode extends BasePlanNode {
 
   @Override
   public PlanNode withInputs(List<PlanNode> inputs) {
-    return new MailboxSendNode(_stageId, _dataSchema, inputs, 
_receiverStageId, _exchangeType, _distributionType, _keys,
+    return new MailboxSendNode(_stageId, _dataSchema, inputs, _receiverStages, 
_exchangeType, _distributionType, _keys,
         _prePartitioned, _collations, _sort);
   }
 
@@ -120,14 +190,14 @@ public class MailboxSendNode extends BasePlanNode {
       return false;
     }
     MailboxSendNode that = (MailboxSendNode) o;
-    return _receiverStageId == that._receiverStageId && _prePartitioned == 
that._prePartitioned && _sort == that._sort
-        && _exchangeType == that._exchangeType && _distributionType == 
that._distributionType && Objects.equals(_keys,
-        that._keys) && Objects.equals(_collations, that._collations);
+    return Objects.equals(_receiverStages, that._receiverStages) && 
_prePartitioned == that._prePartitioned
+        && _sort == that._sort && _exchangeType == that._exchangeType && 
_distributionType == that._distributionType
+        && Objects.equals(_keys, that._keys) && Objects.equals(_collations, 
that._collations);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(super.hashCode(), _receiverStageId, _exchangeType, 
_distributionType, _keys, _prePartitioned,
+    return Objects.hash(super.hashCode(), _receiverStages, _exchangeType, 
_distributionType, _keys, _prePartitioned,
         _collations, _sort);
   }
 
@@ -135,7 +205,7 @@ public class MailboxSendNode extends BasePlanNode {
   public String toString() {
     return "MailboxSendNode{"
         + "_stageId=" + _stageId
-        + ", _receiverStageId=" + _receiverStageId
+        + ", _receivers=" + _receiverStages
         + '}';
   }
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java
index 0327d89e65..49494f8df6 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java
@@ -69,7 +69,7 @@ public interface PlanNodeVisitor<T, C> {
    *
    * The default implementation for each plan node type does nothing but 
visiting its inputs
    * (see {@link #visitChildren(PlanNode, Object)}) and then returning the 
result of calling
-   * {@link #defaultCase(PlanNode, Object)}.
+   * {@link #postChildren(PlanNode, Object)}.
    *
    * Subclasses can override each method to provide custom behavior for each 
plan node type.
    * For example:
@@ -117,6 +117,17 @@ public interface PlanNodeVisitor<T, C> {
       return true;
     }
 
+    /**
+     * The method that each default visit method calls on a node before 
visiting its children.
+     *
+     * This method can be overridden to provide a default behavior for all 
nodes.
+     *
+     * The returned value of this method is ignored by the default visit 
methods.
+     */
+    protected T preChildren(PlanNode node, C context) {
+      return null;
+    }
+
     /**
      * The method that is called by default to handle a node that does not 
have a specific visit method.
      *
@@ -124,89 +135,102 @@ public interface PlanNodeVisitor<T, C> {
      *
      * The returned value of this method is what each default visit method 
will return.
      */
-    protected T defaultCase(PlanNode node, C context) {
+    protected T postChildren(PlanNode node, C context) {
       return null;
     }
 
     @Override
     public T visitAggregate(AggregateNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitFilter(FilterNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitJoin(JoinNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitMailboxReceive(MailboxReceiveNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
       if (traverseStageBoundary()) {
         node.getSender().visit(this, context);
       }
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitMailboxSend(MailboxSendNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitProject(ProjectNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitSort(SortNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitTableScan(TableScanNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitValue(ValueNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitWindow(WindowNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitSetOp(SetOpNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitExchange(ExchangeNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
 
     @Override
     public T visitExplained(ExplainedNode node, C context) {
+      preChildren(node, context);
       visitChildren(node, context);
-      return defaultCase(node, context);
+      return postChildren(node, context);
     }
   }
 }
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinderTest.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinderTest.java
index 54f101059b..1e7f71f396 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinderTest.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinderTest.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.planner.logical;
 
 import java.util.Map;
+import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.plannode.MailboxSendNode;
 import org.testng.annotations.Test;
@@ -68,6 +69,34 @@ public class EquivalentStagesFinderTest extends 
StagesTestBase {
     assertEquals(result.toString(), "[[0], [1, 2]]");
   }
 
+  @Test
+  void sameDistributionKeepEquivalence() {
+    when(
+        join(
+            exchange(1, tableScan("T1"))
+                .withDistributionType(RelDistribution.Type.RANDOM_DISTRIBUTED),
+            exchange(2, tableScan("T1"))
+                .withDistributionType(RelDistribution.Type.RANDOM_DISTRIBUTED)
+        )
+    );
+    GroupedStages groupedStages = 
EquivalentStagesFinder.findEquivalentStages(stage(0));
+    assertEquals(groupedStages.toString(), "[[0], [1, 2]]");
+  }
+
+  @Test
+  void differentDistributionBreakEquivalence() {
+    when(
+        join(
+            exchange(1, tableScan("T1"))
+                .withDistributionType(RelDistribution.Type.RANDOM_DISTRIBUTED),
+            exchange(2, tableScan("T1"))
+                
.withDistributionType(RelDistribution.Type.BROADCAST_DISTRIBUTED)
+        )
+    );
+    GroupedStages groupedStages = 
EquivalentStagesFinder.findEquivalentStages(stage(0));
+    assertEquals(groupedStages.toString(), "[[0], [1], [2]]");
+  }
+
   @Test
   public void sameHintsDontBreakEquivalence() {
     when(
@@ -89,7 +118,7 @@ public class EquivalentStagesFinderTest extends 
StagesTestBase {
   }
 
   @Test
-  public void differentHintsImplyNotEquivalent() {
+  public void differentHintsBreakEquivalence() {
     when(
         join(
             exchange(
@@ -109,7 +138,7 @@ public class EquivalentStagesFinderTest extends 
StagesTestBase {
   }
 
   @Test
-  public void differentHintsOneNullImplyNotEquivalent() {
+  public void differentHintsOneNullBreakEquivalence() {
     when(
         join(
             exchange(1, tableScan("T1")),
@@ -199,4 +228,18 @@ public class EquivalentStagesFinderTest extends 
StagesTestBase {
     GroupedStages result = 
EquivalentStagesFinder.findEquivalentStages(stage(0));
     assertEquals(result.toString(), "[[0], [1, 2], [3, 5], [4, 6]]");
   }
+
+  @Test
+  void notUniqueReceiversInStage() {
+    when(// stage 0
+        exchange(1,
+            join(
+                exchange(2, tableScan("T1")),
+                exchange(3, tableScan("T1"))
+            )
+        )
+    );
+    GroupedStages groupedStages = 
EquivalentStagesFinder.findEquivalentStages(stage(0));
+    assertEquals(groupedStages.toString(), "[[0], [1], [2, 3]]");
+  }
 }
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacerTest.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacerTest.java
new file mode 100644
index 0000000000..830c8a2d78
--- /dev/null
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacerTest.java
@@ -0,0 +1,146 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.logical;
+
+import org.apache.pinot.query.planner.plannode.MailboxSendNode;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.*;
+
+
+public class EquivalentStagesReplacerTest extends StagesTestBase {
+
+  @Test
+  public void test() {
+    when(// stage 0
+        exchange(1,
+          join(
+              exchange(2,
+                join(
+                  exchange(3, tableScan("T1")),
+                  exchange(4, tableScan("T2"))
+                )
+              ),
+              exchange(5,
+                join(
+                  exchange(6, tableScan("T1")),
+                  exchange(7, tableScan("T3"))
+                )
+              )
+          )
+        )
+    );
+
+    GroupedStages groupedStages = 
EquivalentStagesFinder.findEquivalentStages(stage(0));
+    assertEquals(groupedStages.toString(), "[[0], [1], [2], [3, 6], [4], [5], 
[7]]");
+
+    MailboxSendNode rootStage = stage(0);
+    EquivalentStagesReplacer.replaceEquivalentStages(rootStage, groupedStages);
+
+    cleanup();
+    SpoolBuilder readT1 = new SpoolBuilder(3, tableScan("T1"));
+    MailboxSendNode expected = when(// stage 0
+        exchange(1,
+            join(
+                exchange(2,
+                    join(
+                        readT1.newReceiver(),
+                        exchange(4, tableScan("T2"))
+                    )
+                ),
+                exchange(5,
+                    join(
+                        readT1.newReceiver(),
+                        exchange(7, tableScan("T3"))
+                    )
+                )
+            )
+        )
+    );
+
+    assertEqualPlan(rootStage, expected);
+  }
+
+  @Test
+  void notUniqueReceiversInStage() {
+    when(// stage 0
+        exchange(1,
+            join(
+                exchange(2, tableScan("T1")),
+                exchange(3, tableScan("T1"))
+            )
+        )
+    );
+    GroupedStages groupedStages = 
EquivalentStagesFinder.findEquivalentStages(stage(0));
+    assertEquals(groupedStages.toString(), "[[0], [1], [2, 3]]");
+
+    MailboxSendNode rootStage = stage(0);
+    EquivalentStagesReplacer.replaceEquivalentStages(rootStage, groupedStages);
+
+    cleanup();
+    MailboxSendNode expected = when(// stage 0
+        exchange(1,
+            join(
+                exchange(2, tableScan("T1")),
+                exchange(3, tableScan("T1"))
+            )
+        )
+    );
+    assertEqualPlan(rootStage, expected);
+  }
+
+  @Test
+  void groupSendingToTheSameStage() {
+    when(// stage 0
+        exchange(1,
+            join(
+                exchange(2, tableScan("T1")),
+                exchange(3,
+                    join(
+                        exchange(4, tableScan("T1")),
+                        exchange(5, tableScan("T1"))
+                    )
+                )
+            )
+        )
+    );
+    GroupedStages groupedStages = 
EquivalentStagesFinder.findEquivalentStages(stage(0));
+    assertEquals(groupedStages.toString(), "[[0], [1], [2, 4, 5], [3]]");
+
+    MailboxSendNode rootStage = stage(0);
+    EquivalentStagesReplacer.replaceEquivalentStages(rootStage, groupedStages);
+
+    cleanup();
+    SpoolBuilder readT1 = new SpoolBuilder(2, tableScan("T1"));
+    MailboxSendNode expected = when(// stage 0
+        exchange(1,
+            join(
+                readT1.newReceiver(),
+                exchange(3,
+                    join(
+                        readT1.newReceiver(),
+                        exchange(5, tableScan("T1"))
+                    )
+                )
+            )
+        )
+    );
+    assertEqualPlan(rootStage, expected);
+  }
+}
diff --git 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/StagesTestBase.java
 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/StagesTestBase.java
index 93fc109583..3735a82976 100644
--- 
a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/StagesTestBase.java
+++ 
b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/StagesTestBase.java
@@ -18,20 +18,29 @@
  */
 package org.apache.pinot.query.planner.logical;
 
+import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
 import javax.annotation.Nullable;
+import org.apache.calcite.rel.RelDistribution;
+import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.pinot.calcite.rel.logical.PinotRelExchangeType;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.plannode.JoinNode;
 import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
 import org.apache.pinot.query.planner.plannode.MailboxSendNode;
 import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.planner.plannode.PlanNodeVisitor;
 import org.apache.pinot.query.planner.plannode.TableScanNode;
+import org.testng.Assert;
 import org.testng.annotations.AfterMethod;
 
 
@@ -116,19 +125,41 @@ public class StagesTestBase {
    * Although there are builder methods to create send and receive mailboxes 
separately, this method is recommended
    * because it deals with the stageId management and creates tests that are 
easier to read.
    */
-  public SimpleChildBuilder<MailboxReceiveNode> exchange(
+  public ExchangeBuilder exchange(
       int nextStageId, SimpleChildBuilder<? extends PlanNode> childBuilder) {
-    return (stageId, mySchema, myHints) -> {
-      PlanNode input = childBuilder.build(nextStageId);
-      MailboxSendNode mailboxSendNode = new MailboxSendNode(nextStageId, null, 
List.of(input), stageId, null, null,
-          null, false, null, false);
-      MailboxSendNode old = _stageRoots.put(nextStageId, mailboxSendNode);
-      Preconditions.checkState(old == null, "Mailbox already exists for 
stageId: %s", nextStageId);
-      return new MailboxReceiveNode(stageId, null, nextStageId, null, null, 
null, null,
-          false, false, mailboxSendNode);
+    return new ExchangeBuilder() {
+      @Override
+      public MailboxReceiveNode build(int stageId, DataSchema dataSchema, 
PlanNode.NodeHint hints,
+          PinotRelExchangeType exchangeType, RelDistribution.Type 
distribution, List<Integer> keys,
+          boolean prePartitioned, List<RelFieldCollation> collations, boolean 
sort, boolean sortedOnSender) {
+        PlanNode input = childBuilder.build(nextStageId);
+        MailboxSendNode mailboxSendNode = new MailboxSendNode(nextStageId, 
input.getDataSchema(), List.of(input),
+            stageId, exchangeType, distribution, keys, prePartitioned, 
collations, sort);
+        MailboxSendNode old = _stageRoots.put(nextStageId, mailboxSendNode);
+        Preconditions.checkState(old == null, "Mailbox already exists for 
stageId: %s", nextStageId);
+        return new MailboxReceiveNode(stageId, input.getDataSchema(), 
nextStageId, exchangeType, distribution, keys,
+            collations, sort, sortedOnSender, mailboxSendNode);
+      }
     };
   }
 
+  public interface ExchangeBuilder extends 
SimpleChildBuilder<MailboxReceiveNode> {
+    MailboxReceiveNode build(int stageId, DataSchema dataSchema, 
PlanNode.NodeHint hints,
+        PinotRelExchangeType exchangeType, RelDistribution.Type distribution, 
List<Integer> keys,
+        boolean prePartitioned, List<RelFieldCollation> collations, boolean 
sort, boolean sortedOnSender);
+
+    default MailboxReceiveNode build(int stageId, DataSchema dataSchema, 
PlanNode.NodeHint hints) {
+      return build(stageId, null, null, null, null, null, false, null, false, 
false);
+    }
+
+    default ExchangeBuilder withDistributionType(RelDistribution.Type 
distribution) {
+      return (stageId, dataSchema, hints, exchangeType, distribution1, keys, 
prePartitioned, collations, sort,
+          sortedOnSender) ->
+        build(stageId, dataSchema, hints, exchangeType, distribution, keys, 
prePartitioned, collations, sort,
+            sortedOnSender);
+    }
+  }
+
   /**
    * Creates a table scan node with the given table name.
    */
@@ -159,8 +190,8 @@ public class StagesTestBase {
       int newStageId, SimpleChildBuilder<? extends PlanNode> childBuilder) {
     return (stageId, mySchema, myHints) -> {
       PlanNode input = childBuilder.build(stageId);
-      MailboxSendNode mailboxSendNode = new MailboxSendNode(newStageId, 
mySchema, List.of(input), stageId, null, null,
-          null, false, null, false);
+      MailboxSendNode mailboxSendNode = new MailboxSendNode(newStageId, 
mySchema, List.of(input), stageId, null,
+          null, null, false, null, false);
       MailboxSendNode old = _stageRoots.put(stageId, mailboxSendNode);
       Preconditions.checkState(old == null, "Mailbox already exists for 
stageId: %s", stageId);
       return mailboxSendNode;
@@ -229,4 +260,216 @@ public class StagesTestBase {
       return build(stageId, null, null);
     }
   }
+
+  /**
+   * A helper class that can be used to create a spool in the context of a 
test.
+   * <p>
+   * These spools are used to create a single sender that will send data to 
multiple receivers.
+   * This class is just a helper to make it easier to create the sender and 
the receivers in a single fluent way during
+   * a test. By definition, a spool breaks the idea that plan nodes are tree-like. Instead, once spools are used, the
+   * plan nodes form a directed graph that should not have cycles. The latter is not enforced by this class; it is the
+   * responsibility of the test writer.
+   * <p>
+   * Graphs are more complex to write in a nice readable way and require some 
mutation on the nodes that are created.
+   * In order to help, this class has two states: the initial state and the sealed state. When a new spool is created,
+   * it is in the initial state and {@link #newReceiver()} can be called multiple times to create multiple
+   * receivers. Once one of these receivers is built, the spool is sealed and no more receivers can be created.
+   * <p>
+   * Usually this class should be used in the following manner:
+   * <p>
+   * <pre>
+   *   SpoolBuilder readT1 = new SpoolBuilder(3, tableScan("T1")); // here the spool is created
+   *   ExchangeBuilder builder = exchange(1,
+   *     join(
+   *       readT1.newReceiver(), // here a new receiver is created
+   *       readT1.newReceiver() // another receiver is created
+   *     )
+   *   );
+   *   // here the builder is called, which recursively calls the build method 
on the receivers, which seals the spool
+   *   when(builder);
+   * </pre>
+   * <p>
+   *
+   * Notice that usually the builder is not stored as a variable but directly 
used as argument to when. For example,
+   * {@code when(exchange(1, ...));}. This is completely fine and recommended. 
The snippet above splits the creation of
+   * the builder from the call to when to make it easier to understand the 
flow of the test.
+   * <p>
+   * This means that if more than one spool is needed in a test, the test 
writer should create multiple instances of
+   * this class.
+   */
+  public static class SpoolBuilder {
+    private final int _senderStageId;
+    /**
+     * The set of receiver builders. A new element is added every time {@link 
#newReceiver()} is called.
+     * When the first builder is built, {@link #seal()} is called, which 
creates the sender node.
+     */
+    private final Set<SpoolReceiverBuilder> _receiverBuilder = 
Collections.newSetFromMap(new IdentityHashMap<>());
+    private MailboxSendNode _sender;
+    private final SimpleChildBuilder<? extends PlanNode> _childBuilder;
+
+    /**
+     * Creates a new spool with the given sender stage id and child builder.
+     *
+     * The child builder will be used to create the child node that will 
generate the data that will be sent to the
+     * multiple receivers.
+     */
+    public SpoolBuilder(int senderStageId, SimpleChildBuilder<? extends 
PlanNode> spoolChildBuilder) {
+      _senderStageId = senderStageId;
+      _childBuilder = spoolChildBuilder;
+    }
+
+    /**
+     * Returns the sender node for this spool.
+     *
+     * This method can only be called after the spool is sealed, otherwise the 
sender won't be available and this method
+     * will fail with an exception.
+     */
+    public MailboxSendNode getSender() {
+      Preconditions.checkState(isSealed(), "Spool not sealed");
+      return _sender;
+    }
+
+    /**
+     * Returns whether the spool is sealed or not.
+     */
+    public boolean isSealed() {
+      return _sender != null;
+    }
+
+    /**
+     * Creates a new receiver builder that can be used to create a new 
receiver for this spool.
+     *
+     * This method is similar to other builder methods (like {@link 
#tableScan(String)} or
+     * {@link #join(SimpleChildBuilder, SimpleChildBuilder)}) and can be 
called multiple times to create multiple
+     * receivers.
+     *
+     * In most scenarios, the overloaded method {@link #newReceiver()} is good 
enough. This method is useful when the
+     * test writer wants to customize the receiver in some way (for example, 
changing the data schema or hints).
+     * The customize function will be called with a base builder that creates the receiver with the same data schema
+     * as the sender and no hints.
+     */
+    public SimpleChildBuilder<MailboxReceiveNode> newReceiver(
+        Function<SimpleChildBuilder<MailboxReceiveNode>, 
SimpleChildBuilder<MailboxReceiveNode>> customize) {
+      Preconditions.checkState(!isSealed(), "Spool already sealed");
+
+      SpoolReceiverBuilder spoolReceiverBuilder = new 
SpoolReceiverBuilder(customize);
+
+      _receiverBuilder.add(spoolReceiverBuilder);
+      return spoolReceiverBuilder;
+    }
+
+
+    /**
+     * Creates a new receiver builder that can be used to create a new 
receiver for this spool.
+     *
+     * This method is similar to other builder methods (like {@link 
#tableScan(String)} or
+     * {@link #join(SimpleChildBuilder, SimpleChildBuilder)}) and can be 
called multiple times to create multiple
+     * receivers.
+     *
+     * This method creates a receiver with the same data schema as the sender 
and no hints. In case the test writer
+     * wants to customize the receiver, the method {@link 
#newReceiver(Function)} should be used.
+     */
+    public SimpleChildBuilder<MailboxReceiveNode> newReceiver() {
+      return newReceiver(a -> a);
+    }
+
+    private void seal() {
+      if (isSealed()) { // for simplicity the seal method may be called 
multiple times
+        return;
+      }
+
+      PlanNode input = _childBuilder.build(_senderStageId);
+      DataSchema mySchema = input.getDataSchema();
+      _sender = new MailboxSendNode(_senderStageId, mySchema, List.of(input), 
null,
+          null, null, false, null, false);
+    }
+
+    /**
+     * This is the internal class returned as a result of the {@link #newReceiver(Function)} method.
+     *
+     * Its instances don't just create the receiver: they also seal the spool and modify the sender to add the
+     * receiver to its list of receivers.
+     */
+    private class SpoolReceiverBuilder implements 
SimpleChildBuilder<MailboxReceiveNode> {
+      @Nullable
+      private MailboxReceiveNode _receiver;
+      private final Function<SimpleChildBuilder<MailboxReceiveNode>, 
SimpleChildBuilder<MailboxReceiveNode>> _customize;
+
+      public SpoolReceiverBuilder(
+          Function<SimpleChildBuilder<MailboxReceiveNode>, 
SimpleChildBuilder<MailboxReceiveNode>> customize) {
+        _customize = customize;
+      }
+
+      @Override
+      public MailboxReceiveNode build(int stageId, @Nullable DataSchema 
dataSchema, @Nullable PlanNode.NodeHint hints) {
+        Preconditions.checkState(dataSchema == null, "Data schema for spool 
must be set internally");
+        Preconditions.checkState(hints == null, "Hints for spool must be set 
internally");
+        if (_receiver == null) {
+          seal();
+          SimpleChildBuilder<MailboxReceiveNode> baseBuilder = 
(currentStageId, ignoreSchema, ignoreHints) -> {
+            DataSchema mySchema = _sender.getDataSchema();
+            return new MailboxReceiveNode(currentStageId, mySchema, 
_senderStageId, null, null, null, null, false,
+                false, _sender);
+          };
+          SimpleChildBuilder<MailboxReceiveNode> receiveBuilder = 
_customize.apply(baseBuilder);
+          _receiver = receiveBuilder.build(stageId);
+          _sender.addReceiver(_receiver);
+        }
+        Preconditions.checkState(_receiver.getStageId() == stageId, "Receiver 
stageId mismatch. "
+                + "Expected %s, received %s", _receiver.getStageId(), stageId);
+        assert _receiver != null;
+        return _receiver;
+      }
+    }
+  }
+
+  public void assertEqualPlan(PlanNode actual, PlanNode expected) {
+    if (expected == null || actual == null) {
+      if (expected == null && actual == null) {
+        return;
+      }
+      throw new AssertionError("Expected: \n" + expected + ", actual: \n" + 
actual);
+    }
+    if (Objects.equals(expected, actual)) {
+      return;
+    }
+    Assert.fail("Expected: \n" + explainNode(expected) + ", actual: \n" + 
explainNode(actual));
+  }
+
+  private String explainNode(PlanNode node) {
+    StringBuilder sb = new StringBuilder();
+    NodePrinter nodePrinter = new NodePrinter(sb);
+    node.visit(nodePrinter, null);
+    return sb.toString();
+  }
+
+  private static class NodePrinter extends 
PlanNodeVisitor.DepthFirstVisitor<Void, Void> {
+    private final StringBuilder _builder;
+    private int _indent;
+
+    public NodePrinter(StringBuilder builder) {
+      _builder = builder;
+    }
+
+    @Override
+    protected Void preChildren(PlanNode node, Void context) {
+      int stageId = node.getStageId();
+      for (int i = 0; i < _indent; i++) {
+        _builder.append("  ");
+      }
+      _builder.append('[')
+          .append(stageId)
+          .append("]: ")
+          .append(node.explain())
+          .append('\n');
+      _indent++;
+      return null;
+    }
+
+    @Override
+    protected Void postChildren(PlanNode node, Void context) {
+      _indent--;
+      return super.postChildren(node, context);
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to