This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 2b6133f4d0 [feature](Nereids): pushdown complex project through
inner/outer Join. (#17365)
2b6133f4d0 is described below
commit 2b6133f4d0742ffa2446a46eee278cb909cd6de4
Author: jakevin <[email protected]>
AuthorDate: Wed Mar 8 12:00:56 2023 +0800
[feature](Nereids): pushdown complex project through inner/outer Join.
(#17365)
---
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../rules/exploration/join/JoinReorderUtils.java | 23 ++--
.../join/PushdownProjectThroughInnerJoin.java | 104 ++++++++++++++
.../join/PushdownProjectThroughSemiJoin.java | 36 +++--
.../join/PushdownProjectThroughInnerJoinTest.java | 151 +++++++++++++++++++++
.../join/PushdownProjectThroughSemiJoinTest.java | 27 ++++
6 files changed, 313 insertions(+), 29 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 13c47df6a4..c439a13ce5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -235,6 +235,7 @@ public enum RuleType {
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
+ PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java
index b723e2e4a1..36c71cf01f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java
@@ -38,15 +38,6 @@ import java.util.stream.Stream;
* Common
*/
class JoinReorderUtils {
- /**
- * check project Expression Input Slot just contains one slot, like:
- * - one SlotReference like a.id
- * - Input Slot size == 1, like abs(a.id) + 1
- */
- static boolean isOneSlotProject(LogicalProject<LogicalJoin<GroupPlan,
GroupPlan>> project) {
- return project.getProjects().stream().allMatch(expr ->
expr.getInputSlotExprIds().size() == 1);
- }
-
static boolean isAllSlotProject(LogicalProject<LogicalJoin<GroupPlan,
GroupPlan>> project) {
return project.getProjects().stream().allMatch(expr -> expr instanceof
Slot);
}
@@ -78,6 +69,13 @@ class JoinReorderUtils {
return new LogicalProject<>(projectExprs, plan);
}
+ public static Plan projectOrSelfInOrder(List<NamedExpression>
projectExprs, Plan plan) { // returns plan itself when projectExprs is empty or already equals plan's output
+ if (projectExprs.isEmpty() || projectExprs.equals(plan.getOutput())) { // List.equals is order-sensitive
+ return plan;
+ }
+ return new LogicalProject<>(projectExprs, plan);
+ }
+
/**
* replace JoinConjuncts by using slots map.
*/
@@ -111,4 +109,11 @@ class JoinReorderUtils {
}
});
}
+
+ public static Set<Slot> joinChildConditionSlots(LogicalJoin<? extends
Plan, ? extends Plan> join, boolean left) { // join condition slots produced by the left (left=true) or right child
+ Set<Slot> childSlots = left ? join.left().getOutputSet() :
join.right().getOutputSet();
+ return join.getConditionSlot().stream()
+ .filter(childSlots::contains)
+ .collect(Collectors.toSet());
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java
new file mode 100644
index 0000000000..c46f2252a2
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.join;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * rule for pushdown project through inner join (only inner joins are matched, see isInnerJoin() in build)
+ */
+public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory
{
+ public static final PushdownProjectThroughInnerJoin INSTANCE = new
PushdownProjectThroughInnerJoin();
+
+ /*
+ * Project Join
+ * | ──► / \
+ * Join Project Project
+ * / \ | |
+ * A B A B
+ */
+ @Override
+ public Rule build() {
+ return logicalProject(logicalJoin())
+ .when(project -> project.child().getJoinType().isInnerJoin())
+ .whenNot(project -> project.child().hasJoinHint())
+ .then(project -> {
+ LogicalJoin<GroupPlan, GroupPlan> join = project.child();
+ Set<ExprId> aOutputExprIdSet =
join.left().getOutputExprIdSet();
+ Set<ExprId> bOutputExprIdSet =
join.right().getOutputExprIdSet();
+
+ // reject hyper edge in Project: each expr must read slots from only one join child.
+ if (!project.getProjects().stream().allMatch(expr -> {
+ Set<ExprId> inputSlotExprIds = expr.getInputSlotExprIds();
+ return aOutputExprIdSet.containsAll(inputSlotExprIds)
+ || bOutputExprIdSet.containsAll(inputSlotExprIds);
+ })) {
+ return null;
+ }
+
+ Map<Boolean, List<NamedExpression>> map =
JoinReorderUtils.splitProjection(project.getProjects(),
+ join.left()); // presumably true -> left-side exprs, false -> right-side; verify splitProjection
+ List<NamedExpression> aProjects = map.get(true);
+ List<NamedExpression> bProjects = map.get(false);
+
+ boolean leftContains = aProjects.stream().anyMatch(e -> !(e
instanceof Slot)); // true if some left-side expr is not a plain slot
+ boolean rightContains = bProjects.stream().anyMatch(e -> !(e
instanceof Slot));
+ // due to JoinCommute, we don't need to consider the case where only the right
side contains non-slot exprs.
+ if (!leftContains) {
+ return null;
+ }
+
+ Builder<NamedExpression> newAProject =
ImmutableList.<NamedExpression>builder().addAll(aProjects);
+ Set<Slot> aConditionSlots =
JoinReorderUtils.joinChildConditionSlots(join, true);
+ Set<Slot> aProjectSlots =
aProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
+ aConditionSlots.stream().filter(slot ->
!aProjectSlots.contains(slot)).forEach(newAProject::add); // keep join-condition slots the new project would otherwise drop
+ Plan newLeft =
JoinReorderUtils.projectOrSelf(newAProject.build(), join.left());
+
+ if (!rightContains) { // right side is all slots: leave it unprojected
+ Plan newJoin = join.withChildren(newLeft, join.right());
+ return JoinReorderUtils.projectOrSelf(new
ArrayList<>(project.getOutput()), newJoin);
+ }
+
+ Builder<NamedExpression> newBProject =
ImmutableList.<NamedExpression>builder().addAll(bProjects);
+ Set<Slot> bConditionSlots =
JoinReorderUtils.joinChildConditionSlots(join, false);
+ Set<Slot> bProjectSlots =
bProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
+ bConditionSlots.stream().filter(slot ->
!bProjectSlots.contains(slot)).forEach(newBProject::add); // keep join-condition slots the new project would otherwise drop
+ Plan newRight =
JoinReorderUtils.projectOrSelf(newBProject.build(), join.right());
+
+ Plan newJoin = join.withChildren(newLeft, newRight);
+ return JoinReorderUtils.projectOrSelfInOrder(new
ArrayList<>(project.getOutput()), newJoin); // keep the original output order on top
+ }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN);
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java
index 57ea9df15d..172de009a7 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java
@@ -49,27 +49,23 @@ public class PushdownProjectThroughSemiJoin extends
OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalProject(logicalJoin())
- .when(project ->
project.child().getJoinType().isLeftSemiOrAntiJoin())
- .when(JoinReorderUtils::isOneSlotProject)
- // Just pushdown project with non-column expr like (t.id + 1)
- .whenNot(JoinReorderUtils::isAllSlotProject)
- .whenNot(project -> project.child().hasJoinHint())
- .then(project -> {
- LogicalJoin<GroupPlan, GroupPlan> join = project.child();
- Set<Slot> aOutputExprIdSet = join.left().getOutputSet();
- Set<Slot> conditionLeftSlots =
join.getConditionSlot().stream()
- .filter(aOutputExprIdSet::contains)
- .collect(Collectors.toSet());
+ .when(project ->
project.child().getJoinType().isLeftSemiOrAntiJoin())
+ // Just pushdown project with at least one non-column expr like (t.id + 1)
+ .whenNot(JoinReorderUtils::isAllSlotProject)
+ .whenNot(project -> project.child().hasJoinHint())
+ .then(project -> {
+ LogicalJoin<GroupPlan, GroupPlan> join = project.child();
+ Set<Slot> conditionLeftSlots =
JoinReorderUtils.joinChildConditionSlots(join, true);
- List<NamedExpression> newProject = new
ArrayList<>(project.getProjects());
- Set<Slot> projectUsedSlots = project.getProjects().stream()
-
.map(NamedExpression::toSlot).collect(Collectors.toSet());
- conditionLeftSlots.stream().filter(slot ->
!projectUsedSlots.contains(slot))
- .forEach(newProject::add);
- Plan newLeft = JoinReorderUtils.projectOrSelf(newProject,
join.left());
- Plan newJoin = join.withChildren(newLeft, join.right());
- return JoinReorderUtils.projectOrSelf(new
ArrayList<>(project.getOutput()), newJoin);
- }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
+ List<NamedExpression> newProject = new
ArrayList<>(project.getProjects());
+ Set<Slot> projectUsedSlots =
project.getProjects().stream().map(NamedExpression::toSlot)
+ .collect(Collectors.toSet());
+ conditionLeftSlots.stream().filter(slot ->
!projectUsedSlots.contains(slot)).forEach(newProject::add); // keep left condition slots the project would otherwise drop
+ Plan newLeft = JoinReorderUtils.projectOrSelf(newProject,
join.left());
+
+ Plan newJoin = join.withChildren(newLeft, join.right());
+ return JoinReorderUtils.projectOrSelf(new
ArrayList<>(project.getOutput()), newJoin);
+ }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
}
List<NamedExpression> sort(List<NamedExpression> projects, Plan sortPlan) {
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java
new file mode 100644
index 0000000000..7d94fa876c
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.join;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported
{
+ private final LogicalOlapScan scan1 =
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+ private final LogicalOlapScan scan2 =
PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+
+ @Test
+ public void pushBothSide() {
+ // project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name
+ List<NamedExpression> projectExprs = ImmutableList.of(
+ new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)),
"alias"),
+ scan1.getOutput().get(1),
+ new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)),
"alias"),
+ scan2.getOutput().get(1)
+ );
+ // complex projection contains t1.id, which isn't in the join condition (join is on name)
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(1, 1))
+ .projectExprs(projectExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+
.applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
+ .printlnOrigin()
+ .printlnExploration()
+ .matchesExploration(
+ logicalJoin(
+ logicalProject().when(project ->
project.getProjects().size() == 2),
+ logicalProject().when(project ->
project.getProjects().size() == 2)
+ )
+ );
+ }
+
+ @Test
+ public void pushdownProjectInCondition() {
+ // project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name
+ List<NamedExpression> projectExprs = ImmutableList.of(
+ new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)),
"alias"),
+ scan1.getOutput().get(1),
+ new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)),
"alias"),
+ scan2.getOutput().get(1)
+ );
+ // complex projections contain t1.id / t2.id, which are in the join condition (join is on id)
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .projectExprs(projectExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+
.applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
+ .printlnOrigin()
+ .printlnExploration()
+ .matchesExploration(
+ logicalProject(
+ logicalJoin(
+ logicalProject().when(project ->
project.getProjects().size() == 3),
+ logicalProject().when(project ->
project.getProjects().size() == 3)
+ )
+ )
+ );
+ }
+
+ @Test
+ void pushComplexProject() {
+ // project (t1.id + t1.name) as complex1, (t2.id + t2.name) as complex2
+ List<NamedExpression> projectExprs = ImmutableList.of(
+ new Alias(new Add(scan1.getOutput().get(0),
scan1.getOutput().get(1)), "complex1"),
+ new Alias(new Add(scan2.getOutput().get(0),
scan2.getOutput().get(1)), "complex2")
+ );
+ // complex projections contain t1.id / t2.id, which are in the join condition (join is on id)
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .projectExprs(projectExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+
.applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
+ .printlnOrigin()
+ .printlnExploration()
+ .matchesExploration(
+ logicalProject(
+ logicalJoin(
+ logicalProject()
+ .when(project ->
+
project.getProjects().get(0).toSql().equals("(id + name) AS `complex1`")
+ &&
project.getProjects().get(1).toSql().equals("id")),
+ logicalProject()
+ .when(project ->
+
project.getProjects().get(0).toSql().equals("(id + name) AS `complex2`")
+ &&
project.getProjects().get(1).toSql().equals("id"))
+ )
+ ).when(project ->
project.getProjects().get(0).toSql().equals("complex1")
+ &&
project.getProjects().get(1).toSql().equals("complex2")
+ )
+ );
+ }
+
+ @Test
+ void rejectHyperEdgeProject() {
+ // project (t1.id + t2.id) as alias
+ List<NamedExpression> projectExprs = ImmutableList.of(
+ new Alias(new Add(scan1.getOutput().get(0),
scan2.getOutput().get(0)), "alias")
+ );
+ // the expr reads both t1.id and t2.id (hyper edge), so the rule must not fire
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .projectExprs(projectExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+
.applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
+ .checkMemo(memo -> Assertions.assertEquals(1,
memo.getRoot().getLogicalExpressions().size()));
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java
index 0a3b7a04ba..b47910f748 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java
@@ -95,4 +95,31 @@ class PushdownProjectThroughSemiJoinTest implements
MemoPatternMatchSupported {
).when(project -> project.getProjects().size() == 2)
);
}
+
+ @Test
+ void pushComplexProject() {
+ // project (t1.id + t1.name) as complex
+ List<NamedExpression> projectExprs = ImmutableList.of(
+ new Alias(new Add(scan1.getOutput().get(0),
scan1.getOutput().get(1)), "complex"));
+ // complex projection contains t1.id, which is in the join condition (join is on id)
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
+ .projectExprs(projectExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+
.applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build())
+ .printlnOrigin()
+ .printlnExploration()
+ .matchesExploration(
+ logicalProject(
+ leftSemiLogicalJoin(
+ logicalProject()
+ .when(project ->
project.getProjects().get(0).toSql().equals("(id + name) AS `complex`")
+ &&
project.getProjects().get(1).toSql().equals("id")),
+ logicalOlapScan()
+ )
+ ).when(project ->
project.getProjects().get(0).toSql().equals("complex"))
+ );
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]