This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 7c7ac86fe8 [feature](Nereids): Left deep tree join order. (#12439)
7c7ac86fe8 is described below
commit 7c7ac86fe8e3360e39b96d2c517aa25b61b18426
Author: jakevin <[email protected]>
AuthorDate: Thu Sep 8 15:09:22 2022 +0800
[feature](Nereids): Left deep tree join order. (#12439)
* [feature](Nereids): Left deep tree join order.
---
.../org/apache/doris/nereids/rules/RuleSet.java | 44 +++-
.../rules/exploration/join/JoinCommute.java | 74 +++---
.../rules/exploration/join/JoinCommuteProject.java | 66 ------
.../rules/exploration/join/JoinLAsscom.java | 50 +++--
.../rules/exploration/join/JoinLAsscomHelper.java | 250 +++++----------------
.../rules/exploration/join/JoinLAsscomProject.java | 55 +++--
...inCommuteHelper.java => JoinReorderCommon.java} | 36 ++-
.../exploration/{ => join}/JoinReorderContext.java | 19 +-
.../rules/exploration/join/ThreeJoinHelper.java | 165 ++++++++++++++
.../nereids/trees/plans/logical/LogicalJoin.java | 2 +-
.../apache/doris/nereids/util/ExpressionUtils.java | 22 ++
.../rules/exploration/join/JoinCommuteTest.java | 30 ++-
.../exploration/join/JoinLAsscomProjectTest.java | 134 -----------
.../rules/exploration/join/JoinLAsscomTest.java | 156 +++++--------
.../org/apache/doris/nereids/util/PlanChecker.java | 37 ++-
15 files changed, 548 insertions(+), 592 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
index fb2696788c..936834605b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
@@ -18,7 +18,8 @@
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
-import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
+import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
+import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
import
org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
import
org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
import
org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
@@ -32,6 +33,7 @@ import
org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalPro
import
org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
+import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
@@ -43,8 +45,10 @@ import java.util.List;
*/
public class RuleSet {
+ // Default exploration rules: left-deep join commute and inner-join LAsscom
+ // (with and without a project between the joins). MergeConsecutiveProjects is
+ // presumably here to collapse the projects those join rules put on top of the
+ // reordered joins -- confirm.
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
- .add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG)
- .add(JoinCommuteProject.SWAP_OUTER_SWAP_ZIG_ZAG)
+ .add(JoinCommute.OUTER_LEFT_DEEP)
+ .add(JoinLAsscom.INNER)
+ .add(JoinLAsscomProject.INNER)
+ .add(new MergeConsecutiveProjects())
.build();
public static final List<Rule> REWRITE_RULES = planRuleFactories()
@@ -66,6 +70,40 @@ public class RuleSet {
.add(new LogicalEmptyRelationToPhysicalEmptyRelation())
.build();
+ // Left-deep trees: JoinCommute.OUTER_LEFT_DEEP only commutes bottom joins;
+ // combined with inner and outer LAsscom (plain and project variants).
+ public static final List<Rule> LEFT_DEEP_TREE_JOIN_REORDER =
planRuleFactories()
+ .add(JoinCommute.OUTER_LEFT_DEEP)
+ .add(JoinLAsscom.INNER)
+ .add(JoinLAsscomProject.INNER)
+ .add(JoinLAsscom.OUTER)
+ .add(JoinLAsscomProject.OUTER)
+ // semi join Transpose ....
+ .build();
+
+ // Zig-zag trees: JoinCommute.OUTER_ZIG_ZAG may also commute non-bottom joins
+ // (flagging them hasCommuteZigZag); combined with inner and outer LAsscom.
+ public static final List<Rule> ZIG_ZAG_TREE_JOIN_REORDER =
planRuleFactories()
+ .add(JoinCommute.OUTER_ZIG_ZAG)
+ .add(JoinLAsscom.INNER)
+ .add(JoinLAsscomProject.INNER)
+ .add(JoinLAsscom.OUTER)
+ .add(JoinLAsscomProject.OUTER)
+ // semi join Transpose ....
+ .build();
+
+ // Bushy trees: most associate/exchange rules are still TODO.
+ // NOTE(review): unlike the sets above, this keeps JoinLAsscom.OUTER but not
+ // JoinLAsscom.INNER or the JoinLAsscomProject variants -- confirm intended.
+ public static final List<Rule> BUSHY_TREE_JOIN_REORDER =
planRuleFactories()
+ .add(JoinCommute.OUTER_BUSHY)
+ // TODO: add more rule
+ // .add(JoinLeftAssociate.INNER)
+ // .add(JoinLeftAssociateProject.INNER)
+ // .add(JoinRightAssociate.INNER)
+ // .add(JoinRightAssociateProject.INNER)
+ // .add(JoinExchange.INNER)
+ // .add(JoinExchangeBothProject.INNER)
+ // .add(JoinExchangeLeftProject.INNER)
+ // .add(JoinExchangeRightProject.INNER)
+ // .add(JoinRightAssociate.OUTER)
+ .add(JoinLAsscom.OUTER)
+ // semi join Transpose ....
+ .build();
+
public List<Rule> getExplorationRules() {
return EXPLORATION_RULES;
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
index 64ebe5171f..129b655e5f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
@@ -17,54 +17,72 @@
package org.apache.doris.nereids.rules.exploration.join;
-import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
-import
org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.util.Utils;
+
+import java.util.ArrayList;
+import java.util.List;
/**
* Join Commute
*/
-@Developing
public class JoinCommute extends OneExplorationRuleFactory {
- public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new
JoinCommute(true, SwapType.BOTTOM_JOIN);
- public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new
JoinCommute(true, SwapType.ZIG_ZAG);
+ // Commute only bottom joins (see isNotBottomJoin).
+ public static final JoinCommute OUTER_LEFT_DEEP = new
JoinCommute(SwapType.LEFT_DEEP);
+ // Commute any join; commuted non-bottom joins are flagged hasCommuteZigZag.
+ public static final JoinCommute OUTER_ZIG_ZAG = new
JoinCommute(SwapType.ZIG_ZAG);
+ // Commute any join without the zig-zag flag.
+ public static final JoinCommute OUTER_BUSHY = new
JoinCommute(SwapType.BUSHY);
- private final boolean swapOuter;
private final SwapType swapType;
- public JoinCommute(boolean swapOuter) {
- this.swapOuter = swapOuter;
- this.swapType = SwapType.ALL;
+ /**
+ * Create a commute rule whose applicability is restricted by {@code swapType}.
+ */
+ public JoinCommute(SwapType swapType) {
+ this.swapType = swapType;
}
- public JoinCommute(boolean swapOuter, SwapType swapType) {
- this.swapOuter = swapOuter;
- this.swapType = swapType;
+ /**
+ * Which joins the rule may commute (LEFT_DEEP: bottom joins only).
+ */
+ enum SwapType {
+ LEFT_DEEP, ZIG_ZAG, BUSHY
}
@Override
public Rule build() {
- return innerLogicalJoin().when(JoinCommuteHelper::check).then(join -> {
- // TODO: add project for mapping column output.
- // List<NamedExpression> newOutput = new
ArrayList<>(join.getOutput());
- LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
- join.getJoinType(),
- join.getHashJoinConjuncts(),
- join.getOtherJoinCondition(),
- join.right(), join.left(),
- join.getJoinReorderContext());
- newJoin.getJoinReorderContext().setHasCommute(true);
- // if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
- // newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
- // }
+ return innerLogicalJoin()
+ .when(this::check)
+ .then(join -> {
+ LogicalJoin<GroupPlan, GroupPlan> newJoin = new
LogicalJoin<>(
+ join.getJoinType(),
+ join.getHashJoinConjuncts(),
+ join.getOtherJoinCondition(),
+ join.right(), join.left(),
+ join.getJoinReorderContext());
+ newJoin.getJoinReorderContext().setHasCommute(true);
+ if (swapType == SwapType.ZIG_ZAG && isNotBottomJoin(join))
{
+
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
+ }
+
+ // Project the original join's output over the swapped join (presumably to
+ // restore the pre-swap column order). orElse instead of get(): an empty
+ // output list yields the bare join instead of NoSuchElementException.
+ return JoinReorderCommon.project(new ArrayList<>(join.getOutput()), newJoin).orElse(newJoin);
+ }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
+ }
+
+ /**
+ * LEFT_DEEP only accepts bottom joins; every type skips joins that were
+ * already commuted or exchanged.
+ */
+ private boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
+ if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) {
+ return false;
+ }
+
+ return !join.getJoinReorderContext().hasCommute() &&
!join.getJoinReorderContext().hasExchange();
+ }
+
+ // True when either child appears to contain a join (see containJoin).
+ private boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) {
+ // TODO: tmp way to judge bottomJoin
+ return containJoin(join.left()) || containJoin(join.right());
+ }
- // LogicalProject<LogicalJoin> project = new
LogicalProject<>(newOutput, newJoin);
- return newJoin;
- }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
+ // Heuristic: output slots with more than one distinct qualifier imply a join below.
+ private boolean containJoin(GroupPlan groupPlan) {
+ // TODO: tmp way to judge containJoin
+ List<SlotReference> output = Utils.getOutputSlotReference(groupPlan);
+ // No output slots: nothing to compare, and output.get(0) below would throw.
+ if (output.isEmpty()) {
+ return false;
+ }
+ return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java
deleted file mode 100644
index 07464275a1..0000000000
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java
+++ /dev/null
@@ -1,66 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.exploration.join;
-
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
-import
org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
-import org.apache.doris.nereids.trees.plans.GroupPlan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-
-/**
- * Project-Join commute
- */
-public class JoinCommuteProject extends OneExplorationRuleFactory {
-
- public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new
JoinCommute(true, SwapType.BOTTOM_JOIN);
- public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new
JoinCommute(true, SwapType.ZIG_ZAG);
-
- private final SwapType swapType;
- private final boolean swapOuter;
-
- public JoinCommuteProject(boolean swapOuter) {
- this.swapOuter = swapOuter;
- this.swapType = SwapType.ALL;
- }
-
- public JoinCommuteProject(boolean swapOuter, SwapType swapType) {
- this.swapOuter = swapOuter;
- this.swapType = swapType;
- }
-
- @Override
- public Rule build() {
- return
logicalProject(innerLogicalJoin()).when(JoinCommuteHelper::check).then(project
-> {
- LogicalJoin<GroupPlan, GroupPlan> join = project.child();
- LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
- join.getJoinType(),
- join.getHashJoinConjuncts(),
- join.getOtherJoinCondition(),
- join.right(), join.left(),
- join.getJoinReorderContext());
- newJoin.getJoinReorderContext().setHasCommute(true);
- // if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
- // newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
- // }
-
- return newJoin;
- }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
- }
-}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java
index 65f849cd5d..07d8acaceb 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java
@@ -17,18 +17,42 @@
package org.apache.doris.nereids.rules.exploration.join;
-import org.apache.doris.nereids.annotation.Developing;
+import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import java.util.function.Predicate;
+
/**
* Rule for change inner join LAsscom (associative and commutive).
*/
-@Developing
public class JoinLAsscom extends OneExplorationRuleFactory {
+ // for inner-inner
+ public static final JoinLAsscom INNER = new JoinLAsscom(Type.INNER);
+ // for the (bottom, top) pairs in JoinLAsscomHelper.outerSet:
+ // leftOuter-inner, inner-leftOuter and leftOuter-leftOuter
+ public static final JoinLAsscom OUTER = new JoinLAsscom(Type.OUTER);
+
+ // Accepts the (topJoin, bottomJoin) join-type combination this instance handles.
+ private final Predicate<LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>,
GroupPlan>> typeChecker;
+
+ private final Type type;
+
+ /**
+ * Specify join type: INNER accepts inner-inner; OUTER accepts the
+ * (bottomJoin type, topJoin type) pairs listed in JoinLAsscomHelper.outerSet.
+ */
+ public JoinLAsscom(Type type) {
+ this.type = type;
+ if (type == Type.INNER) {
+ typeChecker = join -> join.getJoinType().isInnerJoin() &&
join.left().getJoinType().isInnerJoin();
+ } else {
+ typeChecker = join -> JoinLAsscomHelper.outerSet.contains(
+ Pair.of(join.left().getJoinType(), join.getJoinType()));
+ }
+ }
+
/*
* topJoin newTopJoin
* / \ / \
@@ -39,18 +63,14 @@ public class JoinLAsscom extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin(logicalJoin(), group())
- .when(JoinLAsscomHelper::check)
- .when(join -> join.getJoinType().isInnerJoin() ||
join.getJoinType().isLeftOuterJoin()
- && (join.left().getJoinType().isInnerJoin() ||
join.left().getJoinType().isLeftOuterJoin()))
- .then(topJoin -> {
-
- LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
- JoinLAsscomHelper helper = JoinLAsscomHelper.of(topJoin,
bottomJoin);
- if (!helper.initJoinOnCondition()) {
- return null;
- }
-
- return helper.newTopJoin();
- }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+ .when(topJoin -> JoinLAsscomHelper.check(type, topJoin,
topJoin.left()))
+ .when(typeChecker)
+ .then(topJoin -> {
+ JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin,
topJoin.left());
+ // Bail out (null) when initJoinOnCondition(), inherited from
+ // ThreeJoinHelper, rejects the join conditions.
+ if (!helper.initJoinOnCondition()) {
+ return null;
+ }
+ return helper.newTopJoin();
+ }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java
index e6fe676406..ac31083bde 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java
@@ -18,29 +18,27 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
-import com.google.common.base.Preconditions;
-import com.google.common.collect.Lists;
+import com.google.common.collect.ImmutableSet;
-import java.util.HashSet;
+import java.util.ArrayList;
import java.util.List;
-import java.util.Optional;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* Common function for JoinLAsscom
*/
-public class JoinLAsscomHelper {
+class JoinLAsscomHelper extends ThreeJoinHelper {
/*
* topJoin newTopJoin
* / \ / \
@@ -48,209 +46,79 @@ public class JoinLAsscomHelper {
* / \ / \
* A B A C
*/
- private final LogicalJoin topJoin;
- private final LogicalJoin<GroupPlan, GroupPlan> bottomJoin;
- private final Plan a;
- private final Plan b;
- private final Plan c;
- private final List<Expression> topHashJoinConjuncts;
- private final List<Expression> bottomHashJoinConjuncts;
- private final List<Expression> allNonHashJoinConjuncts =
Lists.newArrayList();
- private final List<SlotReference> aOutputSlots;
- private final List<SlotReference> bOutputSlots;
- private final List<SlotReference> cOutputSlots;
-
- private final List<Expression> newBottomHashJoinConjuncts =
Lists.newArrayList();
- private final List<Expression> newBottomNonHashJoinConjuncts =
Lists.newArrayList();
-
- private final List<Expression> newTopHashJoinConjuncts =
Lists.newArrayList();
- private final List<Expression> newTopNonHashJoinConjuncts =
Lists.newArrayList();
+ // Pair<bottomJoin, topJoin>
+ // newBottomJoin Type = topJoin Type, newTopJoin Type = bottomJoin Type
+ // final: a shared read-only lookup table for the OUTER rules; it must not be reassignable.
+ public static final Set<Pair<JoinType, JoinType>> outerSet = ImmutableSet.of(
+ Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.INNER_JOIN),
+ Pair.of(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN),
+ Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_OUTER_JOIN));
/**
* Init plan and output.
*/
public JoinLAsscomHelper(LogicalJoin<? extends Plan, GroupPlan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
- this.topJoin = topJoin;
- this.bottomJoin = bottomJoin;
-
- a = bottomJoin.left();
- b = bottomJoin.right();
- c = topJoin.right();
-
- Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(),
- "topJoin hashJoinConjuncts must exist.");
- topHashJoinConjuncts = topJoin.getHashJoinConjuncts();
- if (topJoin.getOtherJoinCondition().isPresent()) {
- allNonHashJoinConjuncts.addAll(
-
ExpressionUtils.extractConjunction(topJoin.getOtherJoinCondition().get()));
- }
-
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
- "bottomJoin onClause must exist.");
- bottomHashJoinConjuncts = bottomJoin.getHashJoinConjuncts();
- if (bottomJoin.getOtherJoinCondition().isPresent()) {
- allNonHashJoinConjuncts.addAll(
-
ExpressionUtils.extractConjunction(bottomJoin.getOtherJoinCondition().get()));
- }
-
- aOutputSlots = Utils.getOutputSlotReference(a);
- bOutputSlots = Utils.getOutputSlotReference(b);
- cOutputSlots = Utils.getOutputSlotReference(c);
- }
-
- public static JoinLAsscomHelper of(LogicalJoin<? extends Plan, GroupPlan>
topJoin,
- LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
- return new JoinLAsscomHelper(topJoin, bottomJoin);
- }
-
- /**
- * Get the onCondition of newTopJoin and newBottomJoin.
- */
- public boolean initJoinOnCondition() {
- for (Expression topJoinOnClauseConjunct : topHashJoinConjuncts) {
- // Ignore join with some OnClause like:
- // Join C = B + A for above example.
- Set<Slot> topJoinUsedSlot =
topJoinOnClauseConjunct.getInputSlots();
- if (topJoinUsedSlot.containsAll(aOutputSlots)
- && topJoinUsedSlot.containsAll(bOutputSlots)
- && topJoinUsedSlot.containsAll(cOutputSlots)) {
- return false;
- }
- }
-
- List<Expression> allHashJoinConjuncts = Lists.newArrayList();
- allHashJoinConjuncts.addAll(topHashJoinConjuncts);
- allHashJoinConjuncts.addAll(bottomHashJoinConjuncts);
-
- Set<Slot> newBottomJoinSlots = new HashSet<>(aOutputSlots);
- newBottomJoinSlots.addAll(cOutputSlots);
-
- for (Expression hashConjunct : allHashJoinConjuncts) {
- Set<Slot> slots = hashConjunct.getInputSlots();
- if (newBottomJoinSlots.containsAll(slots)) {
- newBottomHashJoinConjuncts.add(hashConjunct);
- } else {
- newTopHashJoinConjuncts.add(hashConjunct);
- }
- }
- for (Expression nonHashConjunct : allNonHashJoinConjuncts) {
- Set<SlotReference> slots =
nonHashConjunct.collect(SlotReference.class::isInstance);
- if (newBottomJoinSlots.containsAll(slots)) {
- newBottomNonHashJoinConjuncts.add(nonHashConjunct);
- } else {
- newTopNonHashJoinConjuncts.add(nonHashConjunct);
- }
- }
- // newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are
cross join.
- // Example:
- // A: col1, col2. B: col2, col3. C: col3, col4
- // (A & B on A.col2=B.col2) & C on B.col3=C.col3.
- // (A & B) & C -> (A & C) & B.
- // (A & C) will be cross join (newBottomJoinOnCondition is empty)
- if (newBottomHashJoinConjuncts.isEmpty() ||
newTopHashJoinConjuncts.isEmpty()) {
- return false;
- }
-
- return true;
+ super(topJoin, bottomJoin, bottomJoin.left(), bottomJoin.right(),
topJoin.right());
}
/**
- * Get projectExpr of left and right.
- * Just for project-inside.
+ * Create newTopJoin: (A join C) join B. newBottomJoin takes topJoin's join
+ * type and newTopJoin takes bottomJoin's; the result is wrapped in a project
+ * listing topJoin's original output.
*/
- private Pair<List<NamedExpression>, List<NamedExpression>>
getProjectExprs() {
- Preconditions.checkArgument(topJoin.left() instanceof LogicalProject);
- LogicalProject project = (LogicalProject) topJoin.left();
-
- List<NamedExpression> projectExprs = project.getProjects();
- List<NamedExpression> newRightProjectExprs = Lists.newArrayList();
- List<NamedExpression> newLeftProjectExpr = Lists.newArrayList();
-
- HashSet<SlotReference> bOutputSlotsSet = new HashSet<>(bOutputSlots);
- for (NamedExpression projectExpr : projectExprs) {
- Set<SlotReference> usedSlotRefs =
projectExpr.collect(SlotReference.class::isInstance);
- if (bOutputSlotsSet.containsAll(usedSlotRefs)) {
- newRightProjectExprs.add(projectExpr);
- } else {
- newLeftProjectExpr.add(projectExpr);
+ public Plan newTopJoin() {
+ Pair<List<NamedExpression>, List<NamedExpression>> projectPair =
splitProjectExprs(bOutput);
+ List<NamedExpression> newLeftProjectExpr = projectPair.second;
+ List<NamedExpression> newRightProjectExprs = projectPair.first;
+
+ // If add project to B, we should add all slotReference used by
hashOnCondition.
+ // TODO: Does nonHashOnCondition also need to be considered.
+ Set<SlotReference> onUsedSlotRef =
bottomJoin.getHashJoinConjuncts().stream()
+ .flatMap(expr -> {
+ Set<SlotReference> usedSlotRefs =
expr.collect(SlotReference.class::isInstance);
+ return usedSlotRefs.stream();
+
}).filter(Utils.getOutputSlotReference(bottomJoin)::contains).collect(Collectors.toSet());
+ boolean existRightProject = !newRightProjectExprs.isEmpty();
+ boolean existLeftProject = !newLeftProjectExpr.isEmpty();
+ onUsedSlotRef.forEach(slotRef -> {
+ if (existRightProject && bOutput.contains(slotRef) &&
!newRightProjectExprs.contains(slotRef)) {
+ newRightProjectExprs.add(slotRef);
+ } else if (existLeftProject && aOutput.contains(slotRef) &&
!newLeftProjectExpr.contains(slotRef)) {
+ newLeftProjectExpr.add(slotRef);
}
- }
-
- return Pair.of(newLeftProjectExpr, newRightProjectExprs);
- }
+ });
- private LogicalJoin<GroupPlan, GroupPlan> newBottomJoin() {
- Optional<Expression> bottomNonHashExpr;
- if (newBottomNonHashJoinConjuncts.isEmpty()) {
- bottomNonHashExpr = Optional.empty();
- } else {
- bottomNonHashExpr =
Optional.of(ExpressionUtils.and(newBottomNonHashJoinConjuncts));
+ if (existLeftProject) {
+ newLeftProjectExpr.addAll(cOutput);
}
- return new LogicalJoin(
- bottomJoin.getJoinType(),
- newBottomHashJoinConjuncts,
- bottomNonHashExpr,
- a, c);
- }
+ LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new
LogicalJoin<>(topJoin.getJoinType(),
+ newBottomHashJoinConjuncts,
ExpressionUtils.andByOptional(newBottomNonHashJoinConjuncts), a, c,
+ bottomJoin.getJoinReorderContext());
+ // NOTE(review): assumes LogicalJoin copies the passed reorder context;
+ // otherwise these setters also mutate bottomJoin's context -- confirm.
+ newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
+ newBottomJoin.getJoinReorderContext().setHasCommute(false);
- /**
- * Create topJoin for project-inside.
- */
- public LogicalJoin newProjectTopJoin() {
- Plan left;
- Plan right;
+ Plan left = JoinReorderCommon.project(newLeftProjectExpr,
newBottomJoin).orElse(newBottomJoin);
+ Plan right = JoinReorderCommon.project(newRightProjectExprs,
b).orElse(b);
- List<NamedExpression> newLeftProjectExpr = getProjectExprs().first;
- List<NamedExpression> newRightProjectExprs = getProjectExprs().second;
- if (!newLeftProjectExpr.isEmpty()) {
- left = new LogicalProject<>(newLeftProjectExpr, newBottomJoin());
- } else {
- left = newBottomJoin();
- }
- if (!newRightProjectExprs.isEmpty()) {
- right = new LogicalProject<>(newRightProjectExprs, b);
- } else {
- right = b;
- }
- Optional<Expression> topNonHashExpr;
- if (newTopNonHashJoinConjuncts.isEmpty()) {
- topNonHashExpr = Optional.empty();
- } else {
- topNonHashExpr =
Optional.of(ExpressionUtils.and(newTopNonHashJoinConjuncts));
- }
- return new LogicalJoin<>(
- topJoin.getJoinType(),
+ LogicalJoin<Plan, Plan> newTopJoin = new
LogicalJoin<>(bottomJoin.getJoinType(),
newTopHashJoinConjuncts,
- topNonHashExpr,
- left, right);
- }
+ ExpressionUtils.andByOptional(newTopNonHashJoinConjuncts),
left, right,
+ topJoin.getJoinReorderContext());
+ newTopJoin.getJoinReorderContext().setHasLAsscom(true);
- /**
- * Create topJoin for no-project-inside.
- */
- public LogicalJoin newTopJoin() {
- // TODO: add column map (use project)
- // SlotReference bind() may have solved this problem.
- // source: | A | B | C |
- // target: | A | C | B |
- Optional<Expression> topNonHashExpr;
- if (newTopNonHashJoinConjuncts.isEmpty()) {
- topNonHashExpr = Optional.empty();
- } else {
- topNonHashExpr =
Optional.of(ExpressionUtils.and(newTopNonHashJoinConjuncts));
- }
- return new LogicalJoin(
- topJoin.getJoinType(),
- newTopHashJoinConjuncts,
- topNonHashExpr,
- newBottomJoin(), b);
+ // orElse instead of get(): an empty output yields the bare join rather than
+ // NoSuchElementException; consistent with left/right above.
+ return JoinReorderCommon.project(new ArrayList<>(topJoin.getOutput()), newTopJoin).orElse(newTopJoin);
}
- public static boolean check(LogicalJoin topJoin) {
- if (topJoin.getJoinReorderContext().hasCommute()) {
- return false;
+ /**
+ * Reorder-context guard. INNER: skip when bottomJoin was zig-zag commuted or
+ * topJoin already came from LAsscom. OUTER: skip after left/right associate or
+ * exchange on topJoin, or when bottomJoin was commuted.
+ */
+ public static boolean check(Type type, LogicalJoin<? extends Plan,
GroupPlan> topJoin,
+ LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
+ if (type == Type.INNER) {
+ return !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
+ && !topJoin.getJoinReorderContext().hasLAsscom();
+ } else {
+ // hasCommute will cause to lack of OuterJoinAssocRule:Left
+ return !topJoin.getJoinReorderContext().hasLeftAssociate()
+ && !topJoin.getJoinReorderContext().hasRightAssociate()
+ && !topJoin.getJoinReorderContext().hasExchange()
+ && !bottomJoin.getJoinReorderContext().hasCommute();
}
- return true;
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java
index 9876ed29fd..5bbd120b52 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java
@@ -17,18 +17,43 @@
package org.apache.doris.nereids.rules.exploration.join;
-import org.apache.doris.nereids.annotation.Developing;
+import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+
+import java.util.function.Predicate;
/**
- * Rule for change inner join left associative to right.
+ * Rule for LAsscom (left associative and commutative) reordering of
+ * inner/left-outer joins when a project sits between topJoin and bottomJoin.
*/
-@Developing
public class JoinLAsscomProject extends OneExplorationRuleFactory {
+ // for inner-inner
+ public static final JoinLAsscomProject INNER = new
JoinLAsscomProject(Type.INNER);
+ // for the (bottom, top) pairs in JoinLAsscomHelper.outerSet:
+ // leftOuter-inner, inner-leftOuter and leftOuter-leftOuter
+ public static final JoinLAsscomProject OUTER = new
JoinLAsscomProject(Type.OUTER);
+
+ // Accepts the join-type combination of topJoin and the join under the project.
+ private final Predicate<LogicalJoin<LogicalProject<LogicalJoin<GroupPlan,
GroupPlan>>, GroupPlan>> typeChecker;
+
+ private final Type type;
+
+ /**
+ * Specify join type: INNER accepts inner-inner; OUTER accepts the
+ * (bottomJoin type, topJoin type) pairs listed in JoinLAsscomHelper.outerSet.
+ */
+ public JoinLAsscomProject(Type type) {
+ this.type = type;
+ if (type == Type.INNER) {
+ typeChecker = join -> join.getJoinType().isInnerJoin() &&
join.left().child().getJoinType().isInnerJoin();
+ } else {
+ typeChecker = join -> JoinLAsscomHelper.outerSet.contains(
+ Pair.of(join.left().child().getJoinType(),
join.getJoinType()));
+ }
+ }
+
/*
* topJoin newTopJoin
* / \ / \
@@ -41,19 +66,15 @@ public class JoinLAsscomProject extends
OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin(logicalProject(logicalJoin()), group())
- .when(JoinLAsscomHelper::check)
- .when(join -> join.getJoinType().isInnerJoin() ||
join.getJoinType().isLeftOuterJoin()
- && (join.left().child().getJoinType().isInnerJoin() ||
join.left().child().getJoinType()
- .isLeftOuterJoin()))
- .then(topJoin -> {
- LogicalJoin<GroupPlan, GroupPlan> bottomJoin =
topJoin.left().child();
-
- JoinLAsscomHelper helper = JoinLAsscomHelper.of(topJoin,
bottomJoin);
- if (!helper.initJoinOnCondition()) {
- return null;
- }
-
- return helper.newProjectTopJoin();
- }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+ .when(topJoin -> JoinLAsscomHelper.check(type, topJoin,
topJoin.left().child()))
+ .when(typeChecker)
+ .then(topJoin -> {
+ JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin,
topJoin.left().child());
+ // Feed the intermediate project to the helper before splitting conditions.
+ helper.initAllProject(topJoin.left());
+ if (!helper.initJoinOnCondition()) {
+ return null;
+ }
+ return helper.newTopJoin();
+ }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java
similarity index 54%
rename from
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java
rename to
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java
index 47da5030e1..2a4f81138a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java
@@ -17,32 +17,24 @@
package org.apache.doris.nereids.rules.exploration.join;
-import org.apache.doris.nereids.trees.plans.GroupPlan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-/**
- * Common function for JoinCommute
- */
-public class JoinCommuteHelper {
+import java.util.List;
+import java.util.Optional;
- enum SwapType {
- BOTTOM_JOIN, ZIG_ZAG, ALL
+/**
+ * Static helpers shared by the join reorder rules.
+ */
+class JoinReorderCommon {
+ /**
+ * Join-type class a reorder rule applies to (see JoinLAsscom / JoinLAsscomProject).
+ */
+ public enum Type {
+ INNER,
+ OUTER
}
- private final boolean swapOuter;
- private final SwapType swapType;
-
- public JoinCommuteHelper(boolean swapOuter, SwapType swapType) {
- this.swapOuter = swapOuter;
- this.swapType = swapType;
- }
-
- public static boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
- return !join.getJoinReorderContext().hasCommute() &&
!join.getJoinReorderContext().hasExchange();
- }
-
- public static boolean check(LogicalProject<LogicalJoin<GroupPlan,
GroupPlan>> project) {
- return check(project.child());
+ // Utility class: only static members, so prevent instantiation.
+ private JoinReorderCommon() {
+ }
+
+ /**
+ * Wrap {@code plan} in a LogicalProject over {@code projectExprs}.
+ *
+ * @return the new project, or empty when {@code projectExprs} is empty
+ */
+ public static Optional<Plan> project(List<NamedExpression> projectExprs,
Plan plan) {
+ if (!projectExprs.isEmpty()) {
+ return Optional.of(new LogicalProject<>(projectExprs, plan));
+ } else {
+ return Optional.empty();
+ }
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderContext.java
similarity index 87%
rename from
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java
rename to
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderContext.java
index 8384934d42..44166b625f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderContext.java
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-package org.apache.doris.nereids.rules.exploration;
+package org.apache.doris.nereids.rules.exploration.join;
/**
* JoinReorderContext for Duplicate free.
@@ -26,6 +26,7 @@ package org.apache.doris.nereids.rules.exploration;
public class JoinReorderContext {
// left deep tree
private boolean hasCommute = false;
+ private boolean hasLAsscom = false;
// zig-zag tree
private boolean hasCommuteZigZag = false;
@@ -38,16 +39,24 @@ public class JoinReorderContext {
public JoinReorderContext() {
}
+ /**
+ * copy a JoinReorderContext.
+ */
public void copyFrom(JoinReorderContext joinReorderContext) {
this.hasCommute = joinReorderContext.hasCommute;
+ this.hasLAsscom = joinReorderContext.hasLAsscom;
this.hasExchange = joinReorderContext.hasExchange;
this.hasLeftAssociate = joinReorderContext.hasLeftAssociate;
this.hasRightAssociate = joinReorderContext.hasRightAssociate;
this.hasCommuteZigZag = joinReorderContext.hasCommuteZigZag;
}
+ /**
+ * clear all.
+ */
public void clear() {
hasCommute = false;
+ hasLAsscom = false;
hasCommuteZigZag = false;
hasExchange = false;
hasRightAssociate = false;
@@ -62,6 +71,14 @@ public class JoinReorderContext {
this.hasCommute = hasCommute;
}
+ public boolean hasLAsscom() {
+ return hasLAsscom;
+ }
+
+ public void setHasLAsscom(boolean hasLAsscom) {
+ this.hasLAsscom = hasLAsscom;
+ }
+
public boolean hasExchange() {
return hasExchange;
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java
new file mode 100644
index 0000000000..fdf70f2c05
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java
@@ -0,0 +1,165 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.join;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Common join helper for three-join.
+ */
+abstract class ThreeJoinHelper {
+ protected final LogicalJoin<? extends Plan, ? extends Plan> topJoin;
+ protected final LogicalJoin<GroupPlan, GroupPlan> bottomJoin;
+ protected final GroupPlan a;
+ protected final GroupPlan b;
+ protected final GroupPlan c;
+
+ protected final List<SlotReference> aOutput;
+ protected final List<SlotReference> bOutput;
+ protected final List<SlotReference> cOutput;
+
+ protected final List<NamedExpression> allProjects = Lists.newArrayList();
+
+ protected final List<Expression> allHashJoinConjuncts =
Lists.newArrayList();
+ protected final List<Expression> allNonHashJoinConjuncts =
Lists.newArrayList();
+
+ protected final List<Expression> newBottomHashJoinConjuncts =
Lists.newArrayList();
+ protected final List<Expression> newBottomNonHashJoinConjuncts =
Lists.newArrayList();
+
+ protected final List<Expression> newTopHashJoinConjuncts =
Lists.newArrayList();
+ protected final List<Expression> newTopNonHashJoinConjuncts =
Lists.newArrayList();
+
+ /**
+ * Init plan and output.
+ */
+ public ThreeJoinHelper(LogicalJoin<? extends Plan, ? extends Plan> topJoin,
+ LogicalJoin<GroupPlan, GroupPlan> bottomJoin, GroupPlan a,
GroupPlan b, GroupPlan c) {
+ this.topJoin = topJoin;
+ this.bottomJoin = bottomJoin;
+ this.a = a;
+ this.b = b;
+ this.c = c;
+
+ aOutput = Utils.getOutputSlotReference(a);
+ bOutput = Utils.getOutputSlotReference(b);
+ cOutput = Utils.getOutputSlotReference(c);
+
+ Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(),
"topJoin hashJoinConjuncts must exist.");
+
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
+ "bottomJoin hashJoinConjuncts must exist.");
+
+ allHashJoinConjuncts.addAll(topJoin.getHashJoinConjuncts());
+ allHashJoinConjuncts.addAll(bottomJoin.getHashJoinConjuncts());
+ topJoin.getOtherJoinCondition().ifPresent(otherJoinCondition ->
allNonHashJoinConjuncts.addAll(
+ ExpressionUtils.extractConjunction(otherJoinCondition)));
+ bottomJoin.getOtherJoinCondition().ifPresent(otherJoinCondition ->
allNonHashJoinConjuncts.addAll(
+ ExpressionUtils.extractConjunction(otherJoinCondition)));
+ }
+
+ @SafeVarargs
+ public final void initAllProject(LogicalProject<? extends Plan>...
projects) {
+ for (LogicalProject<? extends Plan> project : projects) {
+ allProjects.addAll(project.getProjects());
+ }
+ }
+
+ /**
+ * Get the onCondition of newTopJoin and newBottomJoin.
+ */
+ public boolean initJoinOnCondition() {
+ // Ignore join with some OnClause like:
+ // Join C = B + A for above example.
+ // TODO: also need for otherJoinCondition
+ for (Expression topJoinOnClauseConjunct :
topJoin.getHashJoinConjuncts()) {
+ Set<SlotReference> topJoinUsedSlot =
topJoinOnClauseConjunct.collect(SlotReference.class::isInstance);
+ if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutput) &&
ExpressionUtils.isIntersecting(
+ topJoinUsedSlot, bOutput) &&
ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutput)) {
+ return false;
+ }
+ }
+
+ Set<Slot> newBottomJoinSlots = new HashSet<>(aOutput);
+ newBottomJoinSlots.addAll(cOutput);
+ for (Expression hashConjunct : allHashJoinConjuncts) {
+ Set<SlotReference> slots =
hashConjunct.collect(SlotReference.class::isInstance);
+ if (newBottomJoinSlots.containsAll(slots)) {
+ newBottomHashJoinConjuncts.add(hashConjunct);
+ } else {
+ newTopHashJoinConjuncts.add(hashConjunct);
+ }
+ }
+ for (Expression nonHashConjunct : allNonHashJoinConjuncts) {
+ Set<SlotReference> slots =
nonHashConjunct.collect(SlotReference.class::isInstance);
+ if (newBottomJoinSlots.containsAll(slots)) {
+ newBottomNonHashJoinConjuncts.add(nonHashConjunct);
+ } else {
+ newTopNonHashJoinConjuncts.add(nonHashConjunct);
+ }
+ }
+ // newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are
cross join.
+ // Example:
+ // A: col1, col2. B: col2, col3. C: col3, col4
+ // (A & B on A.col2=B.col2) & C on B.col3=C.col3.
+ // (A & B) & C -> (A & C) & B.
+ // (A & C) will be cross join (newBottomJoinOnCondition is empty)
+ if (newBottomHashJoinConjuncts.isEmpty() ||
newTopHashJoinConjuncts.isEmpty()) {
+ return false;
+ }
+
+ return true;
+ }
+
+ /**
+ * Split inside-project into two part.
+ *
+ * @param topJoinChild output of topJoin groupPlan child.
+ */
+ protected Pair<List<NamedExpression>, List<NamedExpression>>
splitProjectExprs(List<SlotReference> topJoinChild) {
+ List<NamedExpression> newTopJoinChildProjectExprs =
Lists.newArrayList();
+ List<NamedExpression> newBottomJoinProjectExprs = Lists.newArrayList();
+
+ HashSet<SlotReference> topJoinOutputSlotsSet = new
HashSet<>(topJoinChild);
+
+ for (NamedExpression projectExpr : allProjects) {
+ Set<SlotReference> usedSlotRefs =
projectExpr.collect(SlotReference.class::isInstance);
+ if (topJoinOutputSlotsSet.containsAll(usedSlotRefs)) {
+ newTopJoinChildProjectExprs.add(projectExpr);
+ } else {
+ newBottomJoinProjectExprs.add(projectExpr);
+ }
+ }
+ return Pair.of(newTopJoinChildProjectExprs, newBottomJoinProjectExprs);
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
index 12052cbeb5..61f34a84c3 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
@@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
-import org.apache.doris.nereids.rules.exploration.JoinReorderContext;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 7f22a4b2fb..a716e3e43f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import com.google.common.base.Preconditions;
@@ -29,6 +30,7 @@ import com.google.common.collect.Sets;
import java.util.List;
import java.util.Objects;
+import java.util.Optional;
import java.util.Set;
/**
@@ -76,6 +78,14 @@ public class ExpressionUtils {
}
}
+ public static Optional<Expression> andByOptional(List<Expression>
expressions) {
+ if (expressions.isEmpty()) {
+ return Optional.empty();
+ } else {
+ return Optional.of(ExpressionUtils.and(expressions));
+ }
+ }
+
public static Expression and(List<Expression> expressions) {
return combine(And.class, expressions);
}
@@ -120,4 +130,16 @@ public class ExpressionUtils {
.reduce(type == And.class ? And::new : Or::new)
.orElse(new BooleanLiteral(type == And.class));
}
+
+ /**
+ * Check whether lhs and rhs are intersecting.
+ */
+ public static boolean isIntersecting(Set<SlotReference> lhs,
List<SlotReference> rhs) {
+ for (SlotReference rh : rhs) {
+ if (lhs.contains(rh)) {
+ return true;
+ }
+ }
+ return false;
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java
index 27f7cbeae7..46e31ead85 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java
@@ -17,8 +17,8 @@
package org.apache.doris.nereids.rules.exploration.join;
-import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
@@ -26,8 +26,10 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
@@ -35,7 +37,6 @@ import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-import java.util.List;
import java.util.Optional;
public class JoinCommuteTest {
@@ -51,14 +52,23 @@ public class JoinCommuteTest {
JoinType.INNER_JOIN, Lists.newArrayList(onCondition),
Optional.empty(), scan1, scan2);
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(join);
- Rule rule = new JoinCommute(true).build();
+ PlanChecker.from(MemoTestUtils.createConnectContext(), join)
+ .transform(JoinCommute.OUTER_LEFT_DEEP.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(2,
root.getLogicalExpressions().size());
- List<Plan> transform = rule.transform(join, cascadesContext);
- Assertions.assertEquals(1, transform.size());
- Plan newJoin = transform.get(0);
+
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof
LogicalJoin);
+
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof
LogicalProject);
- Assertions.assertEquals(join.child(0), newJoin.child(1));
- Assertions.assertEquals(join.child(1), newJoin.child(0));
+ GroupExpression newJoinGroupExpr =
root.logicalExpressionsAt(1).child(0).getLogicalExpression();
+ Plan left =
newJoinGroupExpr.child(0).getLogicalExpression().getPlan();
+ Plan right =
newJoinGroupExpr.child(1).getLogicalExpression().getPlan();
+ Assertions.assertTrue(left instanceof LogicalOlapScan);
+ Assertions.assertTrue(right instanceof LogicalOlapScan);
+
+ Assertions.assertEquals("t2", ((LogicalOlapScan)
left).getTable().getName());
+ Assertions.assertEquals("t1", ((LogicalOlapScan)
right).getTable().getName());
+ });
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java
deleted file mode 100644
index cb70125c43..0000000000
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java
+++ /dev/null
@@ -1,134 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.exploration.join;
-
-import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.nereids.util.Utils;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
-
-import java.util.List;
-import java.util.Optional;
-
-public class JoinLAsscomProjectTest {
-
- private static final List<LogicalOlapScan> scans = Lists.newArrayList();
- private static final List<List<SlotReference>> outputs =
Lists.newArrayList();
-
- @BeforeAll
- public static void init() {
- LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
- LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
- LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
-
- scans.add(scan1);
- scans.add(scan2);
- scans.add(scan3);
-
- List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
- List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
- List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
-
- outputs.add(t1Output);
- outputs.add(t2Output);
- outputs.add(t3Output);
- }
-
- private Pair<LogicalJoin, LogicalJoin>
testJoinProjectLAsscom(List<NamedExpression> projects) {
- /*
- * topJoin newTopJoin
- * / \ / \
- * project C newLeftProject newRightProject
- * / ──► / \
- * bottomJoin newBottomJoin B
- * / \ / \
- * A B A C
- */
-
- Assertions.assertEquals(3, scans.size());
-
- List<SlotReference> t1 = outputs.get(0);
- List<SlotReference> t2 = outputs.get(1);
- List<SlotReference> t3 = outputs.get(2);
- Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
- Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1));
-
- LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>> project
= new LogicalProject<>(
- projects,
- new LogicalJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(bottomJoinOnCondition),
- Optional.empty(), scans.get(0), scans.get(1)));
-
- LogicalJoin<LogicalProject<LogicalJoin<LogicalOlapScan,
LogicalOlapScan>>, LogicalOlapScan> topJoin
- = new LogicalJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(topJoinOnCondition),
- Optional.empty(), project, scans.get(2));
-
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(topJoin);
- Rule rule = new JoinLAsscomProject().build();
- List<Plan> transform = rule.transform(topJoin, cascadesContext);
- Assertions.assertEquals(1, transform.size());
- Assertions.assertTrue(transform.get(0) instanceof LogicalJoin);
- LogicalJoin newTopJoin = (LogicalJoin) transform.get(0);
- return Pair.of(topJoin, newTopJoin);
- }
-
- @Test
- public void testStarJoinProjectLAsscom() {
- List<SlotReference> t1 = outputs.get(0);
- List<SlotReference> t2 = outputs.get(1);
- List<NamedExpression> projects = ImmutableList.of(
- new Alias(t2.get(0), "t2.id"),
- new Alias(t1.get(0), "t1.id"),
- t1.get(1),
- t2.get(1)
- );
-
- Pair<LogicalJoin, LogicalJoin> pair = testJoinProjectLAsscom(projects);
-
- LogicalJoin oldJoin = pair.first;
- LogicalJoin newTopJoin = pair.second;
-
- // Join reorder successfully.
- Assertions.assertNotEquals(oldJoin, newTopJoin);
- Assertions.assertEquals("t1.id", ((Alias) ((LogicalProject)
newTopJoin.left()).getProjects().get(0)).getName());
- Assertions.assertEquals("name",
- ((SlotReference) ((LogicalProject)
newTopJoin.left()).getProjects().get(1)).getName());
- Assertions.assertEquals("t2.id",
- ((Alias) ((LogicalProject)
newTopJoin.right()).getProjects().get(0)).getName());
- Assertions.assertEquals("name",
- ((SlotReference) ((LogicalProject)
newTopJoin.left()).getProjects().get(1)).getName());
-
- }
-}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java
index 6ed20fa125..7d7f1d8b05 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java
@@ -17,24 +17,23 @@
package org.apache.doris.nereids.rules.exploration.join;
-import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.util.List;
@@ -42,55 +41,13 @@ import java.util.Optional;
public class JoinLAsscomTest {
- private static List<LogicalOlapScan> scans = Lists.newArrayList();
- private static List<List<SlotReference>> outputs = Lists.newArrayList();
-
- @BeforeAll
- public static void init() {
- LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
- LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
- LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
-
- scans.add(scan1);
- scans.add(scan2);
- scans.add(scan3);
-
- List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
- List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
- List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
- outputs.add(t1Output);
- outputs.add(t2Output);
- outputs.add(t3Output);
- }
+ private final LogicalOlapScan scan1 =
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+ private final LogicalOlapScan scan2 =
PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+ private final LogicalOlapScan scan3 =
PlanConstructor.newLogicalOlapScan(2, "t3", 0);
- public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom(
- Expression bottomJoinOnCondition,
- Expression bottomNonHashExpression,
- Expression topJoinOnCondition,
- Expression topNonHashExpression) {
- /*
- * topJoin newTopJoin
- * / \ / \
- * bottomJoin C --> newBottomJoin B
- * / \ / \
- * A B A C
- */
- Assertions.assertEquals(3, scans.size());
- LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new
LogicalJoin<>(JoinType.INNER_JOIN,
- Lists.newArrayList(bottomJoinOnCondition),
- Optional.of(bottomNonHashExpression), scans.get(0),
scans.get(1));
- LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>,
LogicalOlapScan> topJoin = new LogicalJoin<>(
- JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
- Optional.of(topNonHashExpression), bottomJoin, scans.get(2));
-
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(topJoin);
- Rule rule = new JoinLAsscom().build();
- List<Plan> transform = rule.transform(topJoin, cascadesContext);
- Assertions.assertEquals(1, transform.size());
- Assertions.assertTrue(transform.get(0) instanceof LogicalJoin);
- LogicalJoin newTopJoin = (LogicalJoin) transform.get(0);
- return Pair.of(topJoin, newTopJoin);
- }
+ private final List<SlotReference> t1Output =
Utils.getOutputSlotReference(scan1);
+ private final List<SlotReference> t2Output =
Utils.getOutputSlotReference(scan2);
+ private final List<SlotReference> t3Output =
Utils.getOutputSlotReference(scan3);
@Test
public void testStarJoinLAsscom() {
@@ -109,31 +66,35 @@ public class JoinLAsscomTest {
* t1 t2 t1 t3
*/
- List<SlotReference> t1 = outputs.get(0);
- List<SlotReference> t2 = outputs.get(1);
- List<SlotReference> t3 = outputs.get(2);
- Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
- Expression bottomNonHashExpression = new LessThan(t1.get(0),
t2.get(0));
- Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1));
- Expression topNonHashCondition = new LessThan(t1.get(1), t3.get(1));
-
- Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(
- bottomJoinOnCondition,
- bottomNonHashExpression,
- topJoinOnCondition,
- topNonHashCondition);
- LogicalJoin oldJoin = pair.first;
- LogicalJoin newTopJoin = pair.second;
-
- // Join reorder successfully.
- Assertions.assertNotEquals(oldJoin, newTopJoin);
- Assertions.assertEquals("t1",
- ((LogicalOlapScan) ((LogicalJoin)
newTopJoin.left()).left()).getTable().getName());
- Assertions.assertEquals("t3",
- ((LogicalOlapScan) ((LogicalJoin)
newTopJoin.left()).right()).getTable().getName());
- Assertions.assertEquals("t2", ((LogicalOlapScan)
newTopJoin.right()).getTable().getName());
- Assertions.assertEquals(newTopJoin.getOtherJoinCondition(),
- ((LogicalJoin) oldJoin.child(0)).getOtherJoinCondition());
+ Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0),
t2Output.get(0));
+ Expression topJoinOnCondition = new EqualTo(t1Output.get(1),
t3Output.get(1));
+
+ LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new
LogicalJoin<>(JoinType.INNER_JOIN,
+ Lists.newArrayList(bottomJoinOnCondition),
+ Optional.empty(), scan1, scan2);
+ LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>,
LogicalOlapScan> topJoin = new LogicalJoin<>(
+ JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
+ Optional.empty(), bottomJoin, scan3);
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
+ .transform(JoinLAsscom.INNER.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+ Assertions.assertEquals(2,
root.getLogicalExpressions().size());
+
+
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof
LogicalJoin);
+
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof
LogicalProject);
+
+ GroupExpression newTopJoinGroupExpr =
root.logicalExpressionsAt(1).child(0).getLogicalExpression();
+ GroupExpression newBottomJoinGroupExpr =
newTopJoinGroupExpr.child(0).getLogicalExpression();
+ Plan bottomLeft =
newBottomJoinGroupExpr.child(0).getLogicalExpression().getPlan();
+ Plan bottomRight =
newBottomJoinGroupExpr.child(1).getLogicalExpression().getPlan();
+ Plan right =
newTopJoinGroupExpr.child(1).getLogicalExpression().getPlan();
+
+ Assertions.assertEquals("t1", ((LogicalOlapScan)
bottomLeft).getTable().getName());
+ Assertions.assertEquals("t3", ((LogicalOlapScan)
bottomRight).getTable().getName());
+ Assertions.assertEquals("t2", ((LogicalOlapScan)
right).getTable().getName());
+ });
}
@Test
@@ -151,27 +112,22 @@ public class JoinLAsscomTest {
* t1 t2 t1 t3
*/
- List<SlotReference> t1 = outputs.get(0);
- List<SlotReference> t2 = outputs.get(1);
- List<SlotReference> t3 = outputs.get(2);
- Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
- Expression bottomNonHashExpression = new LessThan(t1.get(0),
t2.get(0));
- Expression topJoinOnCondition = new EqualTo(t2.get(0), t3.get(0));
- Expression topNonHashExpression = new LessThan(t2.get(0), t3.get(0));
-
- Pair<LogicalJoin, LogicalJoin> pair =
testJoinLAsscom(bottomJoinOnCondition, bottomNonHashExpression,
- topJoinOnCondition, topNonHashExpression);
- LogicalJoin oldJoin = pair.first;
- LogicalJoin newTopJoin = pair.second;
-
- // Join reorder failed.
- // Chain-Join LAsscom directly will be failed.
- // After t1 -- t2 -- t3
- // -- join commute -->
- // t1 -- t2
- // |
- // t3
- // then, we can LAsscom for this star-join.
- Assertions.assertEquals(oldJoin, newTopJoin);
+ Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0),
t2Output.get(0));
+ Expression topJoinOnCondition = new EqualTo(t2Output.get(0),
t3Output.get(0));
+ LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new
LogicalJoin<>(JoinType.INNER_JOIN,
+ Lists.newArrayList(bottomJoinOnCondition),
+ Optional.empty(), scan1, scan2);
+ LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>,
LogicalOlapScan> topJoin = new LogicalJoin<>(
+ JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
+ Optional.empty(), bottomJoin, scan3);
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
+ .transform(JoinLAsscom.INNER.build())
+ .checkMemo(memo -> {
+ Group root = memo.getRoot();
+
+ // TODO: need infer onCondition.
+ Assertions.assertEquals(1,
root.getLogicalExpressions().size());
+ });
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
index 0f43ed09eb..3b03d468a6 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
@@ -23,7 +23,6 @@ import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.pattern.GroupExpressionMatching;
-import
org.apache.doris.nereids.pattern.GroupExpressionMatching.GroupExpressionIterator;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.pattern.PatternDescriptor;
import org.apache.doris.nereids.pattern.PatternMatcher;
@@ -145,10 +144,8 @@ public class PlanChecker {
public PlanChecker transform(GroupExpression groupExpression,
PatternMatcher patternMatcher) {
GroupExpressionMatching matchResult = new
GroupExpressionMatching(patternMatcher.pattern, groupExpression);
- GroupExpressionIterator iterator = matchResult.iterator();
- while (iterator.hasNext()) {
- Plan before = iterator.next();
+ for (Plan before : matchResult) {
Plan after = patternMatcher.matchedAction.apply(
new MatchingContext(before, patternMatcher.pattern,
cascadesContext));
if (before != after) {
@@ -162,6 +159,38 @@ public class PlanChecker {
return this;
}
+ public PlanChecker transform(Rule rule) {
+ return transform(cascadesContext.getMemo().getRoot(), rule);
+ }
+
+ public PlanChecker transform(Group group, Rule rule) {
+ // copy groupExpressions can prevent ConcurrentModificationException
+ for (GroupExpression logicalExpression :
Lists.newArrayList(group.getLogicalExpressions())) {
+ transform(logicalExpression, rule);
+ }
+
+ for (GroupExpression physicalExpression :
Lists.newArrayList(group.getPhysicalExpressions())) {
+ transform(physicalExpression, rule);
+ }
+ return this;
+ }
+
+ public PlanChecker transform(GroupExpression groupExpression, Rule rule) {
+ GroupExpressionMatching matchResult = new
GroupExpressionMatching(rule.getPattern(), groupExpression);
+
+ for (Plan before : matchResult) {
+ Plan after = rule.transform(before, cascadesContext).get(0);
+ if (before != after) {
+ cascadesContext.getMemo().copyIn(after,
before.getGroupExpression().get().getOwnerGroup(), false);
+ }
+ }
+
+ for (Group childGroup : groupExpression.children()) {
+ transform(childGroup, rule);
+ }
+ return this;
+ }
+
public PlanChecker matchesFromRoot(PatternDescriptor<? extends Plan>
patternDesc) {
Memo memo = cascadesContext.getMemo();
assertMatches(memo, () -> new
GroupExpressionMatching(patternDesc.pattern,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]