This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new d4cebb39ba [fix](Nereids): fix SemiJoinLogicalJoinTransposeProject. (#16883)
d4cebb39ba is described below
commit d4cebb39ba3d7c227947daa95a494282b60de857
Author: jakevin <[email protected]>
AuthorDate: Sat Feb 18 23:12:34 2023 +0800
[fix](Nereids): fix SemiJoinLogicalJoinTransposeProject. (#16883)
---
.../rules/exploration/join/OuterJoinLAsscom.java | 2 +-
.../join/SemiJoinLogicalJoinTransposeProject.java | 95 +++++++++++++---------
.../nereids/trees/plans/logical/LogicalJoin.java | 1 -
3 files changed, 56 insertions(+), 42 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
index dda781f327..a23c3f0015 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
@@ -63,8 +63,8 @@ public class OuterJoinLAsscom extends
OneExplorationRuleFactory {
return logicalJoin(logicalJoin(), group())
.when(join ->
VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(),
join.getJoinType())))
.when(topJoin -> checkReorder(topJoin, topJoin.left()))
- .when(topJoin -> checkCondition(topJoin,
topJoin.left().right().getOutputExprIdSet()))
.whenNot(join -> join.hasJoinHint() ||
join.left().hasJoinHint())
+ .when(topJoin -> checkCondition(topJoin,
topJoin.left().right().getOutputExprIdSet()))
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin =
topJoin.left();
GroupPlan a = bottomJoin.left();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
index 76fda42fe5..89a843fffc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
@@ -17,11 +17,13 @@
package org.apache.doris.nereids.rules.exploration.join;
+import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinHint;
@@ -33,9 +35,12 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
-import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
/**
* <ul>
@@ -64,7 +69,6 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
.whenNot(topJoin ->
topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.whenNot(join -> join.hasJoinHint() ||
join.left().child().hasJoinHint())
.when(join -> JoinReorderUtils.checkProject(join.left()))
- .when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project
= topSemiJoin.left();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin =
project.child();
@@ -72,17 +76,17 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
- Set<ExprId> aOutputExprIdSet = a.getOutputExprIdSet();
-
- List<Expression> hashJoinConjuncts =
topSemiJoin.getHashJoinConjuncts();
-
- boolean lasscom = false;
- for (Expression hashJoinConjunct : hashJoinConjuncts) {
- Set<ExprId> usedSlotExprIdSet =
hashJoinConjunct.getInputSlotExprIds();
- lasscom = Utils.isIntersecting(usedSlotExprIdSet,
aOutputExprIdSet) || lasscom;
+ // push topSemiJoin down project, so we need replace
conjuncts by project.
+ Pair<List<Expression>, List<Expression>> conjuncts =
replaceConjuncts(topSemiJoin, project);
+ Set<ExprId> conjunctsIds =
Stream.concat(conjuncts.first.stream(), conjuncts.second.stream())
+ .flatMap(expr ->
expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
+ ContainsType containsType = containsChildren(conjunctsIds,
a.getOutputExprIdSet(),
+ b.getOutputExprIdSet());
+ if (containsType == ContainsType.ALL) {
+ return null;
}
- if (lasscom) {
+ if (containsType == ContainsType.LEFT) {
/*-
* topSemiJoin project
* / \ |
@@ -92,22 +96,24 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
* / \ / \
* A B A C
*/
+ // Preconditions.checkState(bottomJoin.getJoinType()
!= JoinType.RIGHT_OUTER_JOIN);
if (bottomJoin.getJoinType() ==
JoinType.RIGHT_OUTER_JOIN) {
// when bottom join is right outer join, we change
it to inner join
// if we want to do this trans. However, we do not
allow different logical properties
// in one group. So we need to change it to inner
join in rewrite step.
- return topSemiJoin;
+ return null;
}
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin =
new LogicalJoin<>(
- topSemiJoin.getJoinType(),
topSemiJoin.getHashJoinConjuncts(),
- topSemiJoin.getOtherJoinConjuncts(),
JoinHint.NONE, a, c);
+ topSemiJoin.getJoinType(), conjuncts.first,
conjuncts.second, JoinHint.NONE, a, c);
LogicalJoin<Plan, Plan> newTopJoin = new
LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinConjuncts(),
- JoinHint.NONE,
- newBottomSemiJoin, b);
- return JoinReorderUtils.projectOrSelf(new
ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
+ JoinHint.NONE, newBottomSemiJoin, b);
+ return project.withChildren(newTopJoin);
} else {
+ if (leftDeep) {
+ return null;
+ }
/*-
* topSemiJoin project
* / \ |
@@ -121,40 +127,49 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
// when bottom join is left outer join, we change
it to inner join
// if we want to do this trans. However, we do not
allow different logical properties
// in one group. So we need to change it to inner
join in rewrite step.
- return topSemiJoin;
+ return null;
}
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin =
new LogicalJoin<>(
- topSemiJoin.getJoinType(),
topSemiJoin.getHashJoinConjuncts(),
- topSemiJoin.getOtherJoinConjuncts(),
JoinHint.NONE, b, c);
+ topSemiJoin.getJoinType(), conjuncts.first,
conjuncts.second, JoinHint.NONE, b, c);
LogicalJoin<Plan, Plan> newTopJoin = new
LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinConjuncts(),
- JoinHint.NONE,
- a, newBottomSemiJoin);
- return JoinReorderUtils.projectOrSelf(new
ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
+ JoinHint.NONE, a, newBottomSemiJoin);
+ return project.withChildren(newTopJoin);
}
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
}
- // project of bottomJoin just return A OR B, else return false.
- private boolean conditionChecker(
- LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>,
GroupPlan> topSemiJoin) {
- List<Expression> hashJoinConjuncts =
topSemiJoin.getHashJoinConjuncts();
+ private Pair<List<Expression>, List<Expression>>
replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join,
+ LogicalProject<? extends Plan> project) {
+ Map<ExprId, Slot> outputToInput = new HashMap<>();
+ for (NamedExpression outputExpr : project.getProjects()) {
+ Set<Slot> usedSlots = outputExpr.getInputSlots();
+ Preconditions.checkState(usedSlots.size() == 1);
+ Slot inputSlot = usedSlots.iterator().next();
+ outputToInput.put(outputExpr.getExprId(), inputSlot);
+ }
+ List<Expression> topHashConjuncts =
+
JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(),
outputToInput);
+ List<Expression> topOtherConjuncts =
+
JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(),
outputToInput);
+ return Pair.of(topHashConjuncts, topOtherConjuncts);
+ }
- List<Slot> aOutput = topSemiJoin.left().child().left().getOutput();
- List<Slot> bOutput = topSemiJoin.left().child().right().getOutput();
+ enum ContainsType {
+ LEFT, RIGHT, ALL
+ }
- boolean hashContainsA = false;
- boolean hashContainsB = false;
- for (Expression hashJoinConjunct : hashJoinConjuncts) {
- Set<Slot> usedSlot =
hashJoinConjunct.collect(Slot.class::isInstance);
- hashContainsA = Utils.isIntersecting(usedSlot, aOutput) ||
hashContainsA;
- hashContainsB = Utils.isIntersecting(usedSlot, bOutput) ||
hashContainsB;
- }
- if (leftDeep && hashContainsB) {
- return false;
+ private ContainsType containsChildren(Set<ExprId> conjunctsExprIdSet,
Set<ExprId> left, Set<ExprId> right) {
+ boolean containsLeft = Utils.isIntersecting(conjunctsExprIdSet, left);
+ boolean containsRight = Utils.isIntersecting(conjunctsExprIdSet,
right);
+ Preconditions.checkState(containsLeft || containsRight, "join output
must contain child");
+ if (containsLeft && containsRight) {
+ return ContainsType.ALL;
+ } else if (containsLeft) {
+ return ContainsType.LEFT;
+ } else {
+ return ContainsType.RIGHT;
}
- Preconditions.checkState(hashContainsA || hashContainsB, "join output
must contain child");
- return !(hashContainsA && hashContainsB);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
index 72018b914b..4c187674ef 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
@@ -130,7 +130,6 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
return otherJoinConjuncts;
}
- @Override
public List<Expression> getHashJoinConjuncts() {
return hashJoinConjuncts;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]