This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new db8bc80c36 [feature](Nereids): semi join transpose (#12590)
db8bc80c36 is described below
commit db8bc80c36629c0c5d71f6a7cd78aefeb6733e2a
Author: jakevin <[email protected]>
AuthorDate: Thu Sep 15 21:32:50 2022 +0800
[feature](Nereids): semi join transpose (#12590)
* [feature](Nereids): semi join transpose and enable ZIG_ZAG join reorder.
---
.../org/apache/doris/nereids/rules/RuleSet.java | 11 ++
.../org/apache/doris/nereids/rules/RuleType.java | 7 +-
.../rules/exploration/join/JoinLAsscomProject.java | 2 +-
.../join/SemiJoinLogicalJoinTranspose.java | 50 ++++++--
.../join/SemiJoinLogicalJoinTransposeProject.java | 111 ++++++++++-------
.../join/SemiJoinSemiJoinTranspose.java | 3 +-
.../logical/PushdownProjectThroughLimit.java | 2 +-
.../SemiJoinLogicalJoinTransposeProjectTest.java | 134 +++++++++++++++++++++
.../join/SemiJoinLogicalJoinTransposeTest.java | 126 +++++++++++++++++++
...est.java => SemiJoinSemiJoinTransposeTest.java} | 13 +-
.../doris/nereids/util/LogicalPlanBuilder.java | 6 +-
11 files changed, 394 insertions(+), 71 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
index cf2c0ce311..7f11734e1f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
@@ -21,6 +21,9 @@ import
org.apache.doris.nereids.rules.exploration.join.JoinCommute;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
+import
org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose;
+import
org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject;
+import
org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
import
org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
import
org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
import
org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
@@ -55,6 +58,9 @@ public class RuleSet {
.add(JoinCommuteProject.LEFT_DEEP)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
+ .add(SemiJoinLogicalJoinTranspose.LEFT_DEEP)
+ .add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP)
+ .add(SemiJoinSemiJoinTranspose.INSTANCE)
.add(new PushdownFilterThroughProject())
.add(new MergeConsecutiveProjects())
.build();
@@ -140,6 +146,11 @@ public class RuleSet {
return this;
}
+ public RuleFactories addAll(List<Rule> rules) {
+ this.rules.addAll(rules);
+ return this;
+ }
+
public List<Rule> build() {
return rules.build();
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index c9fe816970..c40419dfea 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -109,8 +109,7 @@ public enum RuleType {
OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),
// Pushdown filter
PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE),
- LOGICAL_LIMIT_TO_LOGICAL_EMPTY_RELATION_RULE(RuleTypeClass.REWRITE),
- SWAP_LIMIT_PROJECT(RuleTypeClass.REWRITE),
+ PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE),
REWRITE_SENTINEL(RuleTypeClass.REWRITE),
// limit push down
@@ -122,7 +121,11 @@ public enum RuleType {
LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
LOGICAL_LEFT_JOIN_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_L_ASSCOM(RuleTypeClass.EXPLORATION),
+ LOGICAL_JOIN_L_ASSCOM_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
+ LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
+
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
+ LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION),
// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java
index 5bbd120b52..8c45afaa04 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java
@@ -75,6 +75,6 @@ public class JoinLAsscomProject extends
OneExplorationRuleFactory {
return null;
}
return helper.newTopJoin();
- }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+ }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM_PROJECT);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
index f5c5580d8d..b9ebadfae1 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@@ -27,6 +28,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
+import java.util.List;
import java.util.Set;
/**
@@ -42,9 +44,21 @@ import java.util.Set;
* which operands actually participate in the semi-join.
*/
public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
+
+ public static final SemiJoinLogicalJoinTranspose LEFT_DEEP = new
SemiJoinLogicalJoinTranspose(true);
+
+ public static final SemiJoinLogicalJoinTranspose ALL = new
SemiJoinLogicalJoinTranspose(false);
+
+ private final boolean leftDeep;
+
+ public SemiJoinLogicalJoinTranspose(boolean leftDeep) {
+ this.leftDeep = leftDeep;
+ }
+
@Override
public Rule build() {
return leftSemiLogicalJoin(logicalJoin(), group())
+ .whenNot(topJoin ->
topJoin.left().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin =
topSemiJoin.left();
@@ -52,7 +66,14 @@ public class SemiJoinLogicalJoinTranspose extends
OneExplorationRuleFactory {
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
- boolean lasscom =
bottomJoin.getOutputSet().containsAll(a.getOutput());
+ List<Expression> hashJoinConjuncts =
topSemiJoin.getHashJoinConjuncts();
+ Set<Slot> aOutputSet = a.getOutputSet();
+
+ boolean lasscom = false;
+ for (Expression hashJoinConjunct : hashJoinConjuncts) {
+ Set<Slot> usedSlot =
hashJoinConjunct.collect(Slot.class::isInstance);
+ lasscom = ExpressionUtils.isIntersecting(usedSlot,
aOutputSet) || lasscom;
+ }
if (lasscom) {
/*
@@ -81,20 +102,27 @@ public class SemiJoinLogicalJoinTranspose extends
OneExplorationRuleFactory {
return new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), a,
newBottomSemiJoin);
}
- }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+ }).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE);
}
// bottomJoin just return A OR B, else return false.
- private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan,
GroupPlan>, GroupPlan> topJoin) {
- Set<Slot> bottomOutputSet = topJoin.left().getOutputSet();
-
- Set<Slot> aOutputSet = topJoin.left().left().getOutputSet();
- Set<Slot> bOutputSet = topJoin.left().right().getOutputSet();
+ private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan,
GroupPlan>, GroupPlan> topSemiJoin) {
+ List<Expression> hashJoinConjuncts =
topSemiJoin.getHashJoinConjuncts();
- boolean isProjectA = !ExpressionUtils.isIntersecting(bottomOutputSet,
aOutputSet);
- boolean isProjectB = !ExpressionUtils.isIntersecting(bottomOutputSet,
bOutputSet);
+ List<Slot> aOutput = topSemiJoin.left().left().getOutput();
+ List<Slot> bOutput = topSemiJoin.left().right().getOutput();
- Preconditions.checkState(isProjectA || isProjectB, "join output must
contain child");
- return !(isProjectA && isProjectB);
+ boolean hashContainsA = false;
+ boolean hashContainsB = false;
+ for (Expression hashJoinConjunct : hashJoinConjuncts) {
+ Set<Slot> usedSlot =
hashJoinConjunct.collect(Slot.class::isInstance);
+ hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput)
|| hashContainsA;
+ hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput)
|| hashContainsB;
+ }
+ if (leftDeep && hashContainsB) {
+ return false;
+ }
+ Preconditions.checkState(hashContainsA || hashContainsB, "join output
must contain child");
+ return !(hashContainsA && hashContainsB);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
index 45cdc7a19e..183a1218d1 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
@@ -20,15 +20,17 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
+import java.util.ArrayList;
import java.util.List;
import java.util.Set;
@@ -45,9 +47,20 @@ import java.util.Set;
* which operands actually participate in the semi-join.
*/
public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFactory {
+ public static final SemiJoinLogicalJoinTransposeProject LEFT_DEEP = new
SemiJoinLogicalJoinTransposeProject(true);
+
+ public static final SemiJoinLogicalJoinTransposeProject ALL = new
SemiJoinLogicalJoinTransposeProject(false);
+
+ private final boolean leftDeep;
+
+ public SemiJoinLogicalJoinTransposeProject(boolean leftDeep) {
+ this.leftDeep = leftDeep;
+ }
+
@Override
public Rule build() {
return leftSemiLogicalJoin(logicalProject(logicalJoin()), group())
+ .whenNot(topJoin ->
topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project
= topSemiJoin.left();
@@ -56,67 +69,77 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
- boolean lasscom =
a.getOutputSet().containsAll(project.getOutput());
+ Set<Slot> aOutputSet = a.getOutputSet();
+
+ List<Expression> hashJoinConjuncts =
topSemiJoin.getHashJoinConjuncts();
+
+ boolean lasscom = false;
+ for (Expression hashJoinConjunct : hashJoinConjuncts) {
+ Set<Slot> usedSlot =
hashJoinConjunct.collect(Slot.class::isInstance);
+ lasscom = ExpressionUtils.isIntersecting(usedSlot,
aOutputSet) || lasscom;
+ }
if (lasscom) {
/*-
- * topSemiJoin newTopProject
- * / \ |
+ * topSemiJoin project
+ * / \ |
* project C newTopJoin
* | -> / \
- * bottomJoin newBottomSemiJoin B
+ * bottomJoin newBottomSemiJoin B
* / \ / \
- * A B aNewProject C
- * |
- * A
+ * A B A C
*/
- List<NamedExpression> projects = project.getProjects();
- LogicalProject<GroupPlan> aNewProject = new
LogicalProject<>(projects, a);
- LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>
newBottomSemiJoin = new LogicalJoin<>(
+ LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin =
new LogicalJoin<>(
topSemiJoin.getJoinType(),
topSemiJoin.getHashJoinConjuncts(),
- topSemiJoin.getOtherJoinCondition(),
aNewProject, c);
- LogicalJoin<LogicalJoin<LogicalProject<GroupPlan>,
GroupPlan>, GroupPlan> newTopJoin
- = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(),
- bottomJoin.getOtherJoinCondition(),
newBottomSemiJoin, b);
- return new LogicalProject<>(projects, newTopJoin);
+ topSemiJoin.getOtherJoinCondition(), a, c);
+
+ LogicalJoin<Plan, Plan> newTopJoin = new
LogicalJoin<>(bottomJoin.getJoinType(),
+ bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(),
+ newBottomSemiJoin, b);
+
+ return new LogicalProject<>(new
ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
} else {
/*-
- * topSemiJoin newTopProject
- * / \ |
- * project C newTopJoin
- * | / \
- * bottomJoin C --> A newBottomSemiJoin
- * / \ / \
- * A B bNewProject C
- * |
- * B
+ * topSemiJoin project
+ * / \ |
+ * project C newTopJoin
+ * | / \
+ * bottomJoin C --> A newBottomSemiJoin
+ * / \ / \
+ * A B B C
*/
- List<NamedExpression> projects = project.getProjects();
- LogicalProject<GroupPlan> bNewProject = new
LogicalProject<>(projects, b);
- LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>
newBottomSemiJoin = new LogicalJoin<>(
+ LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin =
new LogicalJoin<>(
topSemiJoin.getJoinType(),
topSemiJoin.getHashJoinConjuncts(),
- topSemiJoin.getOtherJoinCondition(),
bNewProject, c);
+ topSemiJoin.getOtherJoinCondition(), b, c);
+
+ LogicalJoin<Plan, Plan> newTopJoin = new
LogicalJoin<>(bottomJoin.getJoinType(),
+ bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(),
+ a, newBottomSemiJoin);
- LogicalJoin<GroupPlan,
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>> newTopJoin
- = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(),
- bottomJoin.getOtherJoinCondition(), a,
newBottomSemiJoin);
- return new LogicalProject<>(projects, newTopJoin);
+ return new LogicalProject<>(new
ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
}
- }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
}
- // bottomJoin just return A OR B, else return false.
+ // project of bottomJoin just return A OR B, else return false.
private boolean conditionChecker(
- LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>,
GroupPlan> topJoin) {
- Set<Slot> projectOutputSet = topJoin.left().getOutputSet();
-
- Set<Slot> aOutputSet = topJoin.left().child().left().getOutputSet();
- Set<Slot> bOutputSet = topJoin.left().child().right().getOutputSet();
+ LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>,
GroupPlan> topSemiJoin) {
+ List<Expression> hashJoinConjuncts =
topSemiJoin.getHashJoinConjuncts();
- boolean isProjectA = !ExpressionUtils.isIntersecting(projectOutputSet,
aOutputSet);
- boolean isProjectB = !ExpressionUtils.isIntersecting(projectOutputSet,
bOutputSet);
+ List<Slot> aOutput = topSemiJoin.left().child().left().getOutput();
+ List<Slot> bOutput = topSemiJoin.left().child().right().getOutput();
- Preconditions.checkState(isProjectA || isProjectB, "project must
contain child");
- return !(isProjectA && isProjectB);
+ boolean hashContainsA = false;
+ boolean hashContainsB = false;
+ for (Expression hashJoinConjunct : hashJoinConjuncts) {
+ Set<Slot> usedSlot =
hashJoinConjunct.collect(Slot.class::isInstance);
+ hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput)
|| hashContainsA;
+ hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput)
|| hashContainsB;
+ }
+ if (leftDeep && hashContainsB) {
+ return false;
+ }
+ Preconditions.checkState(hashContainsA || hashContainsB, "join output
must contain child");
+ return !(hashContainsA && hashContainsB);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
index 31d326612d..beab255b89 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
@@ -37,6 +37,7 @@ import java.util.Set;
* LEFT-Semi/ANTI(X, LEFT-Semi/ANTI(Y, Z))
*/
public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
+ public static final SemiJoinSemiJoinTranspose INSTANCE = new
SemiJoinSemiJoinTranspose();
public static Set<Pair<JoinType, JoinType>> typeSet = ImmutableSet.of(
Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_SEMI_JOIN),
@@ -69,7 +70,7 @@ public class SemiJoinSemiJoinTranspose extends
OneExplorationRuleFactory {
newBottomJoin, b);
return newTopJoin;
- }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+ }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE);
}
private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>,
GroupPlan> topJoin) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimit.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimit.java
index 230e5f2e98..bd56224d26 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimit.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimit.java
@@ -54,6 +54,6 @@ public class PushdownProjectThroughLimit extends
OneRewriteRuleFactory {
return new
LogicalLimit<LogicalProject<GroupPlan>>(logicalLimit.getLimit(),
logicalLimit.getOffset(), new
LogicalProject<>(logicalProject.getProjects(),
logicalLimit.child()));
- }).toRule(RuleType.SWAP_LIMIT_PROJECT);
+ }).toRule(RuleType.PUSHDOWN_PROJECT_THROUGHT_LIMIT);
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
new file mode 100644
index 0000000000..8111f5b92e
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
@@ -0,0 +1,134 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.join;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class SemiJoinLogicalJoinTransposeProjectTest {
+ private static final LogicalOlapScan scan1 =
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+ private static final LogicalOlapScan scan2 =
PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+ private static final LogicalOlapScan scan3 =
PlanConstructor.newLogicalOlapScan(2, "t3", 0);
+
+ @Test
+ public void testSemiJoinLogicalTransposeProjectLAsscom() {
+ /*-
+ * topSemiJoin project
+ * / \ |
+ * project C newTopJoin
+ * | -> / \
+ * bottomJoin newBottomSemiJoin B
+ * / \ / \
+ * A B A C
+ */
+ LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) //
t1.id = t2.id
+ .project(ImmutableList.of(0))
+ .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
// t1.id = t3.id
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
+
.transform(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(2,
root.getLogicalExpressions().size());
+ Plan plan =
memo.copyOut(root.getLogicalExpressions().get(1), false);
+
+ LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>)
plan.child(0);
+ LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>)
newTopJoin.left();
+ Assertions.assertEquals(JoinType.INNER_JOIN,
newTopJoin.getJoinType());
+ Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
newBottomJoin.getJoinType());
+
+ LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan)
newBottomJoin.left();
+ LogicalOlapScan newBottomJoinRight = (LogicalOlapScan)
newBottomJoin.right();
+ LogicalOlapScan newTopJoinRight = (LogicalOlapScan)
newTopJoin.right();
+
+ Assertions.assertEquals("t1",
newBottomJoinLeft.getTable().getName());
+ Assertions.assertEquals("t3",
newBottomJoinRight.getTable().getName());
+ Assertions.assertEquals("t2",
newTopJoinRight.getTable().getName());
+ });
+ }
+
+ @Test
+ public void testSemiJoinLogicalTransposeProjectLAsscomFail() {
+ LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) //
t1.id = t2.id
+ .project(ImmutableList.of(0, 2)) // t1.id, t2.id
+ .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 0))
// t2.id = t3.id
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
+
.transform(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(1,
root.getLogicalExpressions().size());
+ });
+ }
+
+ @Test
+ public void testSemiJoinLogicalTransposeProjectAll() {
+ /*-
+ * topSemiJoin project
+ * / \ |
+ * project C newTopJoin
+ * | / \
+ * bottomJoin C --> A newBottomSemiJoin
+ * / \ / \
+ * A B B C
+ */
+ LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) //
t1.id = t2.id
+ .project(ImmutableList.of(0, 2)) // t1.id, t2.id
+ .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 0))
// t2.id = t3.id
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
+ .transform(SemiJoinLogicalJoinTransposeProject.ALL.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(2,
root.getLogicalExpressions().size());
+ Plan plan =
memo.copyOut(root.getLogicalExpressions().get(1), false);
+
+ LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>)
plan.child(0);
+ LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>)
newTopJoin.right();
+ Assertions.assertEquals(JoinType.INNER_JOIN,
newTopJoin.getJoinType());
+ Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
newBottomJoin.getJoinType());
+
+ LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan)
newBottomJoin.left();
+ LogicalOlapScan newBottomJoinRight = (LogicalOlapScan)
newBottomJoin.right();
+ LogicalOlapScan newTopJoinLeft = (LogicalOlapScan)
newTopJoin.left();
+
+ Assertions.assertEquals("t1",
newTopJoinLeft.getTable().getName());
+ Assertions.assertEquals("t2",
newBottomJoinLeft.getTable().getName());
+ Assertions.assertEquals("t3",
newBottomJoinRight.getTable().getName());
+ });
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeTest.java
new file mode 100644
index 0000000000..29ba945dbb
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeTest.java
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.join;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class SemiJoinLogicalJoinTransposeTest {
+ private static final LogicalOlapScan scan1 =
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+ private static final LogicalOlapScan scan2 =
PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+ private static final LogicalOlapScan scan3 =
PlanConstructor.newLogicalOlapScan(2, "t3", 0);
+
+ @Test
+ public void testSemiJoinLogicalTransposeLAsscom() {
+ /*
+ * topSemiJoin newTopJoin
+ * / \ / \
+ * bottomJoin C --> newBottomSemiJoin B
+ * / \ / \
+ * A B A C
+ */
+ LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) //
t1.id = t2.id
+ .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
// t1.id = t3.id
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
+ .transform(SemiJoinLogicalJoinTranspose.LEFT_DEEP.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(2,
root.getLogicalExpressions().size());
+ Plan plan =
memo.copyOut(root.getLogicalExpressions().get(1), false);
+
+ LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan;
+ LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>)
newTopJoin.left();
+ Assertions.assertEquals(JoinType.INNER_JOIN,
newTopJoin.getJoinType());
+ Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
newBottomJoin.getJoinType());
+
+ LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan)
newBottomJoin.left();
+ LogicalOlapScan newBottomJoinRight = (LogicalOlapScan)
newBottomJoin.right();
+ LogicalOlapScan newTopJoinRight = (LogicalOlapScan)
newTopJoin.right();
+
+ Assertions.assertEquals("t1",
newBottomJoinLeft.getTable().getName());
+ Assertions.assertEquals("t3",
newBottomJoinRight.getTable().getName());
+ Assertions.assertEquals("t2",
newTopJoinRight.getTable().getName());
+ });
+ }
+
+ @Test
+ public void testSemiJoinLogicalTransposeLAsscomFail() {
+ LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) //
t1.id = t2.id
+ .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(2, 0))
// t2.id = t3.id
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
+ .transform(SemiJoinLogicalJoinTranspose.LEFT_DEEP.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(1,
root.getLogicalExpressions().size());
+ });
+ }
+
+ @Test
+ public void testSemiJoinLogicalTransposeAll() {
+ /*
+ * topSemiJoin newTopJoin
+ * / \ / \
+ * bottomJoin C --> A newBottomSemiJoin
+ * / \ / \
+ * A B B C
+ */
+ LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) //
t1.id = t2.id
+ .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(2, 0))
// t2.id = t3.id
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
+ .transform(SemiJoinLogicalJoinTranspose.ALL.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(2,
root.getLogicalExpressions().size());
+ Plan plan =
memo.copyOut(root.getLogicalExpressions().get(1), false);
+
+ LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan;
+ LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>)
newTopJoin.right();
+ Assertions.assertEquals(JoinType.INNER_JOIN,
newTopJoin.getJoinType());
+ Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
newBottomJoin.getJoinType());
+
+ LogicalOlapScan newTopJoinLeft = (LogicalOlapScan)
newTopJoin.left();
+ LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan)
newBottomJoin.left();
+ LogicalOlapScan newBottomJoinRight = (LogicalOlapScan)
newBottomJoin.right();
+
+ Assertions.assertEquals("t1",
newTopJoinLeft.getTable().getName());
+ Assertions.assertEquals("t2",
newBottomJoinLeft.getTable().getName());
+ Assertions.assertEquals("t3",
newBottomJoinRight.getTable().getName());
+ });
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeTest.java
similarity index 83%
rename from
fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java
rename to
fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeTest.java
index c7aa852449..3091d44613 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeTest.java
@@ -29,11 +29,10 @@ import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
-import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-public class SemiJoinTransposeTest {
+public class SemiJoinSemiJoinTransposeTest {
public static final LogicalOlapScan scan1 =
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
public static final LogicalOlapScan scan2 =
PlanConstructor.newLogicalOlapScan(1, "t2", 0);
public static final LogicalOlapScan scan3 =
PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@@ -41,21 +40,19 @@ public class SemiJoinTransposeTest {
@Test
public void testSemiJoinLogicalTransposeCommute() {
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
- .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .project(ImmutableList.of(0))
+ .hashJoinUsing(scan2, JoinType.LEFT_ANTI_JOIN, Pair.of(0, 0))
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
- .transform((new SemiJoinLogicalJoinTransposeProject()).build())
+ .transform(SemiJoinSemiJoinTranspose.INSTANCE.build())
.checkMemo(memo -> {
Group root = memo.getRoot();
Assertions.assertEquals(2,
root.getLogicalExpressions().size());
- Plan plan =
memo.copyOut(root.getLogicalExpressions().get(1), false);
+ Plan join =
memo.copyOut(root.getLogicalExpressions().get(1), false);
- Plan join = plan.child(0);
Assertions.assertTrue(join instanceof LogicalJoin);
- Assertions.assertEquals(JoinType.INNER_JOIN,
((LogicalJoin<?, ?>) join).getJoinType());
+ Assertions.assertEquals(JoinType.LEFT_ANTI_JOIN,
((LogicalJoin<?, ?>) join).getJoinType());
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
((LogicalJoin<?, ?>) ((LogicalJoin<?, ?>)
join).left()).getJoinType());
});
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
index 950647a674..29ea98d708 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
@@ -58,10 +58,10 @@ public class LogicalPlanBuilder {
return from(project);
}
- public LogicalPlanBuilder project(List<Integer> slots) {
+ public LogicalPlanBuilder project(List<Integer> slotsIndex) {
List<NamedExpression> projectExprs = Lists.newArrayList();
- for (int i = 0; i < slots.size(); i++) {
- projectExprs.add(this.plan.getOutput().get(i));
+ for (Integer index : slotsIndex) {
+ projectExprs.add(this.plan.getOutput().get(index));
}
LogicalProject<LogicalPlan> project = new
LogicalProject<>(projectExprs, this.plan);
return from(project);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]