This is an automated email from the ASF dual-hosted git repository.

caogaofei pushed a commit to branch beyyes/join
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit f7fd85e0bde165b4fcbec0237c0c0dd5fd01b701
Author: Beyyes <[email protected]>
AuthorDate: Tue Aug 13 17:14:24 2024 +0800

    add basic join support
---
 .../plan/planner/plan/node/PlanVisitor.java        |   5 +
 .../planner/iterative/rule/PruneJoinColumns.java   |  37 +++++++
 .../plan/relational/planner/node/JoinNode.java     | 116 ++++++++++++++++++---
 .../plan/relational/planner/node/Patterns.java     |   9 +-
 .../optimizations/LogicalOptimizeFactory.java      |   4 +-
 .../optimizations/PushPredicateIntoTableScan.java  |   8 ++
 6 files changed, 156 insertions(+), 23 deletions(-)

diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanVisitor.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanVisitor.java
index 3bc791fce21..20666c17178 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanVisitor.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanVisitor.java
@@ -630,6 +630,11 @@ public abstract class PlanVisitor<R, C> {
     return visitMultiChildProcess(node, context);
   }
 
+  public R visitJoin(
+      org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode 
node, C context) {
+    return visitTwoChildProcess(node, context);
+  }
+
   public R visitGroupReference(GroupReference node, C context) {
     return visitPlan(node, context);
   }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneJoinColumns.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneJoinColumns.java
new file mode 100644
index 00000000000..5c2592ac04b
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneJoinColumns.java
@@ -0,0 +1,37 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
+
+import java.util.Optional;
+import java.util.Set;
+
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.join;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.utils.MoreLists.filteredCopy;
+
+/** Joins support output symbol selection, so absorb any project-off into the 
node. */
+public class PruneJoinColumns extends ProjectOffPushDownRule<JoinNode> {
+  public PruneJoinColumns() {
+    super(join());
+  }
+
+  @Override
+  protected Optional<PlanNode> pushDownProjectOff(
+      Context context, JoinNode joinNode, Set<Symbol> referencedOutputs) {
+    return Optional.of(
+        new JoinNode(
+            joinNode.getPlanNodeId(),
+            joinNode.getJoinType(),
+            joinNode.getLeftChild(),
+            joinNode.getRightChild(),
+            joinNode.getCriteria(),
+            filteredCopy(joinNode.getLeftOutputSymbols(), 
referencedOutputs::contains),
+            filteredCopy(joinNode.getRightOutputSymbols(), 
referencedOutputs::contains),
+            joinNode.isMaySkipOutputDuplicates(),
+            joinNode.getFilter(),
+            joinNode.getLeftHashSymbol(),
+            joinNode.getRightHashSymbol(),
+            joinNode.isSpillable()));
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
index e412d8182c4..25eafc4c626 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
@@ -2,7 +2,8 @@ package 
org.apache.iotdb.db.queryengine.plan.relational.planner.node;
 
 import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId;
-import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.MultiChildProcessNode;
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor;
+import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.TwoChildProcessNode;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
 import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
@@ -25,11 +26,9 @@ import static 
com.google.common.base.Preconditions.checkArgument;
 import static java.lang.String.format;
 import static java.util.Objects.requireNonNull;
 
-public class JoinNode extends MultiChildProcessNode {
+public class JoinNode extends TwoChildProcessNode {
 
-  private final JoinType type;
-  private final PlanNode left;
-  private final PlanNode right;
+  private final JoinType joinType;
   private final List<EquiJoinClause> criteria;
   private final List<Symbol> leftOutputSymbols;
   private final List<Symbol> rightOutputSymbols;
@@ -45,9 +44,9 @@ public class JoinNode extends MultiChildProcessNode {
   @JsonCreator
   public JoinNode(
       @JsonProperty("id") PlanNodeId id,
-      @JsonProperty("type") JoinType type,
-      @JsonProperty("left") PlanNode left,
-      @JsonProperty("right") PlanNode right,
+      @JsonProperty("type") JoinType joinType,
+      @JsonProperty("left") PlanNode leftChild,
+      @JsonProperty("right") PlanNode rightChild,
       @JsonProperty("criteria") List<EquiJoinClause> criteria,
       @JsonProperty("leftOutputSymbols") List<Symbol> leftOutputSymbols,
       @JsonProperty("rightOutputSymbols") List<Symbol> rightOutputSymbols,
@@ -62,9 +61,9 @@ public class JoinNode extends MultiChildProcessNode {
         // reorderJoinStatsAndCost)
       {
     super(id);
-    requireNonNull(type, "type is null");
-    requireNonNull(left, "left is null");
-    requireNonNull(right, "right is null");
+    requireNonNull(joinType, "type is null");
+    requireNonNull(leftChild, "left is null");
+    requireNonNull(rightChild, "right is null");
     requireNonNull(criteria, "criteria is null");
     requireNonNull(leftOutputSymbols, "leftOutputSymbols is null");
     requireNonNull(rightOutputSymbols, "rightOutputSymbols is null");
@@ -81,9 +80,9 @@ public class JoinNode extends MultiChildProcessNode {
     // requireNonNull(distributionType, "distributionType is null");
     requireNonNull(spillable, "spillable is null");
 
-    this.type = type;
-    this.left = left;
-    this.right = right;
+    this.joinType = joinType;
+    this.leftChild = leftChild;
+    this.rightChild = rightChild;
     this.criteria = ImmutableList.copyOf(criteria);
     this.leftOutputSymbols = ImmutableList.copyOf(leftOutputSymbols);
     this.rightOutputSymbols = ImmutableList.copyOf(rightOutputSymbols);
@@ -98,8 +97,8 @@ public class JoinNode extends MultiChildProcessNode {
     // this.reorderJoinStatsAndCost = requireNonNull(reorderJoinStatsAndCost,
     // "reorderJoinStatsAndCost is null");
 
-    Set<Symbol> leftSymbols = ImmutableSet.copyOf(left.getOutputSymbols());
-    Set<Symbol> rightSymbols = ImmutableSet.copyOf(right.getOutputSymbols());
+    Set<Symbol> leftSymbols = 
ImmutableSet.copyOf(leftChild.getOutputSymbols());
+    Set<Symbol> rightSymbols = 
ImmutableSet.copyOf(rightChild.getOutputSymbols());
 
     checkArgument(
         leftSymbols.containsAll(leftOutputSymbols),
@@ -140,9 +139,56 @@ public class JoinNode extends MultiChildProcessNode {
     //        }
   }
 
+  @Override
+  public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
+    return visitor.visitJoin(this, context);
+  }
+
+  @Override
+  public PlanNode replaceChildren(List<PlanNode> newChildren) {
+    checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 
nodes for JoinNode");
+    return new JoinNode(
+        getPlanNodeId(),
+        joinType,
+        newChildren.get(0),
+        newChildren.get(1),
+        criteria,
+        leftOutputSymbols,
+        rightOutputSymbols,
+        maySkipOutputDuplicates,
+        filter,
+        leftHashSymbol,
+        rightHashSymbol,
+        spillable);
+  }
+
+  @Override
+  public List<Symbol> getOutputSymbols() {
+    return ImmutableList.<Symbol>builder()
+        .addAll(leftOutputSymbols)
+        .addAll(rightOutputSymbols)
+        .build();
+  }
+
   @Override
   public PlanNode clone() {
-    return null;
+    JoinNode joinNode =
+        new JoinNode(
+            getPlanNodeId(),
+            joinType,
+            getLeftChild(),
+            getRightChild(),
+            criteria,
+            leftOutputSymbols,
+            rightOutputSymbols,
+            maySkipOutputDuplicates,
+            filter,
+            leftHashSymbol,
+            rightHashSymbol,
+            spillable);
+    joinNode.setLeftChild(null);
+    joinNode.setRightChild(null);
+    return joinNode;
   }
 
   @Override
@@ -156,6 +202,42 @@ public class JoinNode extends MultiChildProcessNode {
   @Override
   protected void serializeAttributes(DataOutputStream stream) throws 
IOException {}
 
+  public JoinType getJoinType() {
+    return joinType;
+  }
+
+  public List<EquiJoinClause> getCriteria() {
+    return criteria;
+  }
+
+  public List<Symbol> getLeftOutputSymbols() {
+    return leftOutputSymbols;
+  }
+
+  public List<Symbol> getRightOutputSymbols() {
+    return rightOutputSymbols;
+  }
+
+  public boolean isMaySkipOutputDuplicates() {
+    return maySkipOutputDuplicates;
+  }
+
+  public Optional<Expression> getFilter() {
+    return filter;
+  }
+
+  public Optional<Symbol> getLeftHashSymbol() {
+    return leftHashSymbol;
+  }
+
+  public Optional<Symbol> getRightHashSymbol() {
+    return rightHashSymbol;
+  }
+
+  public Optional<Boolean> isSpillable() {
+    return spillable;
+  }
+
   public static class EquiJoinClause {
     private final Symbol left;
     private final Symbol right;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/Patterns.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/Patterns.java
index 44c1854c9cd..8c0a37e5f43 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/Patterns.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/Patterns.java
@@ -84,6 +84,10 @@ public final class Patterns {
     return typeOf(FilterNode.class);
   }
 
+  public static Pattern<JoinNode> join() {
+    return typeOf(JoinNode.class);
+  }
+
   /*public static Pattern<IndexJoinNode> indexJoin()
   {
       return typeOf(IndexJoinNode.class);
@@ -94,11 +98,6 @@ public final class Patterns {
       return typeOf(IndexSourceNode.class);
   }
 
-  public static Pattern<JoinNode> join()
-  {
-      return typeOf(JoinNode.class);
-  }
-
   public static Pattern<DynamicFilterSourceNode> dynamicFilterSource()
   {
       return typeOf(DynamicFilterSourceNode.class);
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
index b2705b9e65a..d0ed1b6e42e 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
@@ -21,6 +21,7 @@ import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.In
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.MergeLimitOverProjectWithSort;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.MergeLimitWithSort;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneFilterColumns;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneJoinColumns;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneLimitColumns;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneOffsetColumns;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneOutputSourceColumns;
@@ -57,7 +58,8 @@ public class LogicalOptimizeFactory {
             new PruneProjectColumns(),
             new PruneSortColumns(),
             new PruneTableScanColumns(plannerContext.getMetadata()),
-            new PruneTopKColumns());
+            new PruneTopKColumns(),
+            new PruneJoinColumns());
     IterativeOptimizer columnPruningOptimizer =
         new IterativeOptimizer(plannerContext, new RuleStatsRecorder(), 
columnPruningRules);
 
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
index 0fc783013ce..e46c298eb00 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
@@ -24,6 +24,7 @@ import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.MultiChildProcessNode;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.SingleChildProcessNode;
+import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.TwoChildProcessNode;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.write.InsertTabletNode;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.write.RelationalInsertTabletNode;
 import org.apache.iotdb.db.queryengine.plan.relational.analyzer.Analysis;
@@ -114,6 +115,13 @@ public class PushPredicateIntoTableScan implements 
PlanOptimizer {
       return node;
     }
 
+    @Override
+    public PlanNode visitTwoChildProcess(TwoChildProcessNode node, Void 
context) {
+      node.setLeftChild(node.getLeftChild().accept(this, context));
+      node.setRightChild(node.getRightChild().accept(this, context));
+      return node;
+    }
+
     @Override
     public PlanNode visitMultiChildProcess(MultiChildProcessNode node, Void 
context) {
       List<PlanNode> rewrittenChildren = new ArrayList<>();

Reply via email to