This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 858e8234d7 [feature](Nereids) add predicates push down on all join
type (#12571)
858e8234d7 is described below
commit 858e8234d742a60729df6c7a90f526d58ef0cc12
Author: morrySnow <[email protected]>
AuthorDate: Thu Sep 15 15:18:42 2022 +0800
[feature](Nereids) add predicates push down on all join type (#12571)
* [feature](Nereids) add predicates push down on all join type
---
.../doris/nereids/jobs/batch/RewriteJob.java | 17 +-
.../org/apache/doris/nereids/rules/RuleSet.java | 17 +-
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../logical/PushDownJoinOtherCondition.java | 99 +++++++++
...ughJoin.java => PushPredicatesThroughJoin.java} | 91 +++++---
.../logical/FindHashConditionForJoinTest.java | 4 +-
.../rules/rewrite/logical/LimitPushDownTest.java | 9 +-
.../logical/PruneOlapScanPartitionTest.java | 8 +-
.../logical/PushDownJoinOtherConditionTest.java | 196 ++++++++++++++++++
.../rewrite/logical/PushDownPredicateTest.java | 228 ---------------------
.../logical/PushPredicateThroughJoinTest.java | 208 +++++++++++++++++++
.../apache/doris/nereids/util/PlanConstructor.java | 8 +-
12 files changed, 594 insertions(+), 292 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/RewriteJob.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/RewriteJob.java
index 242687ba6d..245095abc8 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/RewriteJob.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/RewriteJob.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
+import org.apache.doris.nereids.rules.RuleSet;
import
org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.mv.SelectRollup;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
@@ -27,14 +28,8 @@ import
org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
-import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
-import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
-import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
-import org.apache.doris.nereids.rules.rewrite.logical.PushPredicateThroughJoin;
-import
org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
-import
org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import com.google.common.collect.ImmutableList;
@@ -64,15 +59,9 @@ public class RewriteJob extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new
ExpressionNormalization())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
- .add(topDownBatch(ImmutableList.of(new
FindHashConditionForJoin())))
- .add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
- .add(topDownBatch(ImmutableList.of(new
PushPredicateThroughJoin(),
- new PushdownProjectThroughLimit(),
- new PushdownFilterThroughProject(),
- new MergeConsecutiveProjects(),
- new MergeConsecutiveFilters(),
- new MergeConsecutiveLimits())))
+ .add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES))
+ .add(topDownBatch(ImmutableList.of(new
FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new
AggregateDisassemble())))
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))
.add(topDownBatch(ImmutableList.of(new EliminateLimit())))
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
index fa480135ce..cf2c0ce311 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
@@ -33,9 +33,13 @@ import
org.apache.doris.nereids.rules.implementation.LogicalOneRowRelationToPhys
import
org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject;
import
org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
-import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
+import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
+import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
+import
org.apache.doris.nereids.rules.rewrite.logical.PushDownJoinOtherCondition;
+import
org.apache.doris.nereids.rules.rewrite.logical.PushPredicatesThroughJoin;
import
org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
+import
org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
@@ -55,9 +59,14 @@ public class RuleSet {
.add(new MergeConsecutiveProjects())
.build();
- public static final List<Rule> REWRITE_RULES = planRuleFactories()
- .add(new AggregateDisassemble())
- .build();
+ public static final List<RuleFactory> PUSH_DOWN_JOIN_CONDITION_RULES =
ImmutableList.of(
+ new PushDownJoinOtherCondition(),
+ new PushPredicatesThroughJoin(),
+ new PushdownProjectThroughLimit(),
+ new PushdownFilterThroughProject(),
+ new MergeConsecutiveProjects(),
+ new MergeConsecutiveFilters(),
+ new MergeConsecutiveLimits());
public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
.add(new LogicalAggToPhysicalHashAgg())
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 826acf522f..c9fe816970 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -79,6 +79,7 @@ public enum RuleType {
EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
// predicate push down rules
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
+ PUSH_DOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE),
PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE),
// column prune rules,
COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java
new file mode 100644
index 0000000000..88744fbe65
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Push the other join conditions in LogicalJoin to children.
+ *
+ * <p>Each conjunct of the join's other (non-hash) condition whose input slots are all produced by one
+ * child is moved into a filter on top of that child, provided the join type appears in
+ * {@code PUSH_DOWN_LEFT_VALID_TYPE} (for the left child) or {@code PUSH_DOWN_RIGHT_VALID_TYPE} (for the
+ * right child). A conjunct that references no slot is covered by both children and is pushed to every
+ * side that accepts it. Conjuncts that cannot be pushed stay in the join's other condition; the hash
+ * conjuncts are never touched.
+ */
+public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
+    // A child on the preserved side of an outer or anti join (e.g. the left child of a LEFT OUTER JOIN)
+    // is absent from the matching list below: filtering it under the join would drop rows that the
+    // join still has to emit.
+    private static final ImmutableList<JoinType> PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of(
+            JoinType.INNER_JOIN,
+            JoinType.LEFT_SEMI_JOIN,
+            JoinType.RIGHT_OUTER_JOIN,
+            JoinType.RIGHT_ANTI_JOIN,
+            JoinType.RIGHT_SEMI_JOIN,
+            JoinType.CROSS_JOIN
+    );
+
+    private static final ImmutableList<JoinType> PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of(
+            JoinType.INNER_JOIN,
+            JoinType.LEFT_OUTER_JOIN,
+            JoinType.LEFT_ANTI_JOIN,
+            JoinType.LEFT_SEMI_JOIN,
+            JoinType.RIGHT_SEMI_JOIN,
+            JoinType.CROSS_JOIN
+    );
+
+    @Override
+    public Rule build() {
+        return logicalJoin().then(join -> {
+            if (!join.getOtherJoinCondition().isPresent()) {
+                return null;
+            }
+            // The join type is fixed for this join, so decide once which children may receive conjuncts.
+            boolean canPushLeft = PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType());
+            boolean canPushRight = PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType());
+            if (!canPushLeft && !canPushRight) {
+                return null;
+            }
+
+            Set<Slot> leftOutput = join.left().getOutputSet();
+            Set<Slot> rightOutput = join.right().getOutputSet();
+            List<Expression> leftConjuncts = Lists.newArrayList();
+            List<Expression> rightConjuncts = Lists.newArrayList();
+            // Build the residual condition in a fresh list instead of calling removeAll on the list
+            // returned by extractConjunction, so the rule does not depend on that list being mutable.
+            List<Expression> remainingConjuncts = Lists.newArrayList();
+
+            for (Expression conjunct : ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get())) {
+                boolean toLeft = canPushLeft && allCoveredBy(conjunct, leftOutput);
+                boolean toRight = canPushRight && allCoveredBy(conjunct, rightOutput);
+                if (toLeft) {
+                    leftConjuncts.add(conjunct);
+                }
+                if (toRight) {
+                    rightConjuncts.add(conjunct);
+                }
+                if (!toLeft && !toRight) {
+                    remainingConjuncts.add(conjunct);
+                }
+            }
+
+            if (leftConjuncts.isEmpty() && rightConjuncts.isEmpty()) {
+                return null;
+            }
+
+            Plan left = PlanUtils.filterOrSelf(leftConjuncts, join.left());
+            Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right());
+
+            return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
+                    ExpressionUtils.optionalAnd(remainingConjuncts), left, right);
+        }).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION);
+    }
+
+    /**
+     * Returns whether every input slot of {@code predicate} is contained in {@code inputSlotSet};
+     * a predicate without input slots is covered by any set.
+     */
+    private boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) {
+        return inputSlotSet.containsAll(predicate.getInputSlots());
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java
similarity index 61%
rename from
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java
rename to
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java
index 9859deb747..3cdfac4918 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java
@@ -21,15 +21,17 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
@@ -37,17 +39,39 @@ import java.util.Objects;
import java.util.Set;
/**
- * Push the predicate in the LogicalFilter or LogicalJoin to the join children.
- * todo: Now, only support eq on condition for inner join, support other case
later
+ * Push the predicate in the LogicalFilter to the join children.
*/
-public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
+public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {
+
+ private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT =
ImmutableList.of(
+ JoinType.INNER_JOIN,
+ JoinType.LEFT_OUTER_JOIN,
+ JoinType.LEFT_SEMI_JOIN,
+ JoinType.LEFT_ANTI_JOIN,
+ JoinType.CROSS_JOIN
+ );
+
+ private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_RIGHT =
ImmutableList.of(
+ JoinType.INNER_JOIN,
+ JoinType.RIGHT_OUTER_JOIN,
+ JoinType.RIGHT_SEMI_JOIN,
+ JoinType.RIGHT_ANTI_JOIN,
+ JoinType.CROSS_JOIN
+ );
+
+ private static final ImmutableList<JoinType> COULD_PUSH_EQUAL_TO =
ImmutableList.of(
+ JoinType.INNER_JOIN
+ );
+
/*
* For example:
- * select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5
where a.k1 > 1 and b.k1 > 2
+ * select a.k1, b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5
+ * where a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2
+ *
* Logical plan tree:
* project
* |
- * filter (a.k1 > 1 and b.k1 > 2)
+ * filter (a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2)
* |
* join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5)
* / \
@@ -55,69 +79,72 @@ public class PushPredicateThroughJoin extends
OneRewriteRuleFactory {
* transformed:
* project
* |
- * join (a.k1 = b.k1)
+ * filter(a.k2 > b.k2)
+ * |
+ * join (otherConditions: a.k1 = b.k1)
* / \
- * filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5)
+ * filter(a.k1 > 1 and a.k2 > 2) filter(b.k1 > 2 and b.k2 > 5)
* | |
* scan scan
*/
@Override
public Rule build() {
- return logicalFilter(innerLogicalJoin()).then(filter -> {
+ return logicalFilter(logicalJoin()).then(filter -> {
LogicalJoin<GroupPlan, GroupPlan> join = filter.child();
- Expression wherePredicates = filter.getPredicates();
- Expression onPredicates =
join.getOtherJoinCondition().orElse(BooleanLiteral.TRUE);
+ Expression filterPredicates = filter.getPredicates();
- List<Expression> otherConditions = Lists.newArrayList();
- List<Expression> eqConditions = Lists.newArrayList();
+ List<Expression> filterConditions = Lists.newArrayList();
+ List<Expression> joinConditions = Lists.newArrayList();
Set<Slot> leftInput = join.left().getOutputSet();
Set<Slot> rightInput = join.right().getOutputSet();
-
ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates,
wherePredicates))
+ ExpressionUtils.extractConjunction(filterPredicates)
.forEach(predicate -> {
- if (Objects.nonNull(getJoinCondition(predicate,
leftInput, rightInput))) {
- eqConditions.add(predicate);
+ if (Objects.nonNull(getJoinCondition(predicate,
leftInput, rightInput))
+ &&
COULD_PUSH_EQUAL_TO.contains(join.getJoinType())) {
+ joinConditions.add(predicate);
} else {
- otherConditions.add(predicate);
+ filterConditions.add(predicate);
}
});
List<Expression> leftPredicates = Lists.newArrayList();
List<Expression> rightPredicates = Lists.newArrayList();
- for (Expression p : otherConditions) {
+ for (Expression p : filterConditions) {
Set<Slot> slots = p.getInputSlots();
if (slots.isEmpty()) {
leftPredicates.add(p);
rightPredicates.add(p);
continue;
}
- if (leftInput.containsAll(slots)) {
+ if (leftInput.containsAll(slots) &&
COULD_PUSH_THROUGH_LEFT.contains(join.getJoinType())) {
leftPredicates.add(p);
}
- if (rightInput.containsAll(slots)) {
+ if (rightInput.containsAll(slots) &&
COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) {
rightPredicates.add(p);
}
}
- otherConditions.removeAll(leftPredicates);
- otherConditions.removeAll(rightPredicates);
- otherConditions.addAll(eqConditions);
+ filterConditions.removeAll(leftPredicates);
+ filterConditions.removeAll(rightPredicates);
+ join.getOtherJoinCondition().ifPresent(joinConditions::add);
- return pushDownPredicate(join, otherConditions, leftPredicates,
rightPredicates);
+ return PlanUtils.filterOrSelf(filterConditions,
+ pushDownPredicate(join, joinConditions, leftPredicates,
rightPredicates));
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
}
- private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> joinPlan,
+ private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> join,
List<Expression> joinConditions, List<Expression> leftPredicates,
List<Expression> rightPredicates) {
// todo expr should optimize again using expr rewrite
- Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates,
joinPlan.left());
- Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates,
joinPlan.right());
+ Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, join.left());
+ Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, join.right());
- return new LogicalJoin<>(joinPlan.getJoinType(),
joinPlan.getHashJoinConjuncts(),
+ return new LogicalJoin<>(join.getJoinType(),
join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(joinConditions), leftPlan,
rightPlan);
}
@@ -128,13 +155,13 @@ public class PushPredicateThroughJoin extends
OneRewriteRuleFactory {
ComparisonPredicate comparison = (ComparisonPredicate) predicate;
- Set<Slot> leftSlots = comparison.left().getInputSlots();
- Set<Slot> rightSlots = comparison.right().getInputSlots();
-
- if (!(leftSlots.size() >= 1 && rightSlots.size() >= 1)) {
+ if (!(comparison instanceof EqualTo)) {
return null;
}
+ Set<Slot> leftSlots = comparison.left().getInputSlots();
+ Set<Slot> rightSlots = comparison.right().getInputSlots();
+
if ((leftOutputs.containsAll(leftSlots) &&
rightOutputs.containsAll(rightSlots))
|| (leftOutputs.containsAll(rightSlots) &&
rightOutputs.containsAll(leftSlots))) {
return predicate;
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
index 435942e562..025f2a39eb 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
@@ -57,8 +57,8 @@ import java.util.Optional;
class FindHashConditionForJoinTest {
@Test
public void testFindHashCondition() {
- Plan student = new LogicalOlapScan(PlanConstructor.getNextId(),
PlanConstructor.student, ImmutableList.of(""));
- Plan score = new LogicalOlapScan(PlanConstructor.getNextId(),
PlanConstructor.score, ImmutableList.of(""));
+ Plan student = new
LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
ImmutableList.of(""));
+ Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.score, ImmutableList.of(""));
Slot studentId = student.getOutput().get(0);
Slot gender = student.getOutput().get(1);
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java
index 714b5511a6..f71417e469 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@@ -48,8 +49,8 @@ import java.util.function.Function;
import java.util.stream.Collectors;
class LimitPushDownTest extends TestWithFeService implements
PatternMatchSupported {
- private Plan scanScore = new LogicalOlapScan(PlanConstructor.score);
- private Plan scanStudent = new LogicalOlapScan(PlanConstructor.student);
+ private Plan scanScore = new LogicalOlapScan(new RelationId(0),
PlanConstructor.score);
+ private Plan scanStudent = new LogicalOlapScan(new RelationId(1),
PlanConstructor.student);
@Override
protected void runBeforeAll() throws Exception {
@@ -213,8 +214,8 @@ class LimitPushDownTest extends TestWithFeService
implements PatternMatchSupport
joinType,
joinConditions,
Optional.empty(),
- new LogicalOlapScan(PlanConstructor.score),
- new LogicalOlapScan(PlanConstructor.student)
+ new LogicalOlapScan(new RelationId(0), PlanConstructor.score),
+ new LogicalOlapScan(new RelationId(1), PlanConstructor.student)
);
if (hasProject) {
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
index b6bd36b1f5..ca52fa2360 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
@@ -88,7 +88,7 @@ class PruneOlapScanPartitionTest {
olapTable.getName();
result = "tbl";
}};
- LogicalOlapScan scan = new
LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
+ LogicalOlapScan scan = new
LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
SlotReference slotRef = new SlotReference("col1",
IntegerType.INSTANCE);
Expression expression = new LessThan(slotRef, new IntegerLiteral(4));
LogicalFilter<LogicalOlapScan> filter = new
LogicalFilter<>(expression, scan);
@@ -104,7 +104,7 @@ class PruneOlapScanPartitionTest {
Expression greaterThan6 = new GreaterThan(slotRef, new
IntegerLiteral(6));
Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6);
filter = new LogicalFilter<>(lessThan0OrGreaterThan6, scan);
- scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
+ scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
olapTable);
cascadesContext = MemoTestUtils.createCascadesContext(filter);
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
cascadesContext.topDownRewrite(rules);
@@ -118,7 +118,7 @@ class PruneOlapScanPartitionTest {
Expression lessThanEqual5 =
new LessThanEqual(slotRef, new IntegerLiteral(5));
And greaterThanEqual0AndLessThanEqual5 = new And(greaterThanEqual0,
lessThanEqual5);
- scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
+ scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
olapTable);
filter = new LogicalFilter<>(greaterThanEqual0AndLessThanEqual5, scan);
cascadesContext = MemoTestUtils.createCascadesContext(filter);
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
@@ -153,7 +153,7 @@ class PruneOlapScanPartitionTest {
olapTable.getName();
result = "tbl";
}};
- LogicalOlapScan scan = new
LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
+ LogicalOlapScan scan = new
LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
Expression left = new LessThan(new SlotReference("col1",
IntegerType.INSTANCE), new IntegerLiteral(4));
Expression right = new GreaterThan(new SlotReference("col2",
IntegerType.INSTANCE), new IntegerLiteral(11));
CompoundPredicate and = new And(left, right);
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java
new file mode 100644
index 0000000000..f6f5d664fe
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java
@@ -0,0 +1,203 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.memo.Memo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.nereids.util.PlanRewriter;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+
+import java.util.Optional;
+
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public class PushDownJoinOtherConditionTest {
+
+    private Plan rStudent;
+    private Plan rScore;
+
+    /**
+     * ut before.
+     */
+    @BeforeAll
+    public final void beforeAll() {
+        rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
+                ImmutableList.of(""));
+        rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score,
+                ImmutableList.of(""));
+    }
+
+    @Test
+    public void oneSide() {
+        oneSide(JoinType.CROSS_JOIN, false);
+        oneSide(JoinType.INNER_JOIN, false);
+        oneSide(JoinType.LEFT_OUTER_JOIN, true);
+        oneSide(JoinType.LEFT_SEMI_JOIN, true);
+        oneSide(JoinType.LEFT_ANTI_JOIN, true);
+        oneSide(JoinType.RIGHT_OUTER_JOIN, false);
+        oneSide(JoinType.RIGHT_SEMI_JOIN, false);
+        oneSide(JoinType.RIGHT_ANTI_JOIN, false);
+    }
+
+    /**
+     * Joins student and score on a condition that references only student; student is placed on the
+     * right when {@code testRight} is true, otherwise on the left. The whole condition must end up in a
+     * filter above student and nothing may remain on the join.
+     */
+    private void oneSide(JoinType joinType, boolean testRight) {
+        Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
+        Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
+        Expression condition = ExpressionUtils.and(pushSide1, pushSide2);
+
+        Plan left = rStudent;
+        Plan right = rScore;
+        if (testRight) {
+            left = rScore;
+            right = rStudent;
+        }
+
+        Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right);
+        Plan root = new LogicalProject<>(Lists.newArrayList(), join);
+
+        Group rootGroup = rewrite(root).getRoot();
+
+        Plan shouldJoin = joinPlan(rootGroup);
+        Plan shouldFilter = joinChild(rootGroup, testRight ? 1 : 0);
+        Plan shouldScan = joinChild(rootGroup, testRight ? 0 : 1);
+
+        Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
+        Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
+        Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
+        LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
+        Assertions.assertEquals(condition, actualFilter.getPredicates());
+        // Every conjunct was pushed, so the join must not keep a copy of any of them.
+        Assertions.assertEquals(Optional.empty(), ((LogicalJoin<?, ?>) shouldJoin).getOtherJoinCondition());
+    }
+
+    @Test
+    public void bothSideToBothSide() {
+        bothSideToBothSide(JoinType.CROSS_JOIN);
+        bothSideToBothSide(JoinType.INNER_JOIN);
+        bothSideToBothSide(JoinType.LEFT_SEMI_JOIN);
+        bothSideToBothSide(JoinType.RIGHT_SEMI_JOIN);
+    }
+
+    /**
+     * Joins student (left) and score (right) on one conjunct per side; {@code joinType} accepts pushdown
+     * on both sides, so each conjunct must move into a filter above its own child.
+     */
+    private void bothSideToBothSide(JoinType joinType) {
+        Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
+        Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
+        Expression condition = ExpressionUtils.and(leftSide, rightSide);
+
+        Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), rStudent, rScore);
+        Plan root = new LogicalProject<>(Lists.newArrayList(), join);
+
+        Group rootGroup = rewrite(root).getRoot();
+
+        Plan shouldJoin = joinPlan(rootGroup);
+        Plan leftFilter = joinChild(rootGroup, 0);
+        Plan rightFilter = joinChild(rootGroup, 1);
+
+        Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
+        Assertions.assertTrue(leftFilter instanceof LogicalFilter);
+        Assertions.assertTrue(rightFilter instanceof LogicalFilter);
+        LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
+        LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
+        Assertions.assertEquals(leftSide, actualLeft.getPredicates());
+        Assertions.assertEquals(rightSide, actualRight.getPredicates());
+        Assertions.assertEquals(Optional.empty(), ((LogicalJoin<?, ?>) shouldJoin).getOtherJoinCondition());
+    }
+
+    @Test
+    public void bothSideToOneSide() {
+        bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true);
+        bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true);
+        bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false);
+        bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, false);
+    }
+
+    /**
+     * Joins student and score on one conjunct per side, with student on the side into which
+     * {@code joinType} allows pushdown (the right when {@code testRight} is true). Only the student
+     * conjunct may be pushed; the score conjunct must stay on the join as its other condition.
+     */
+    private void bothSideToOneSide(JoinType joinType, boolean testRight) {
+        Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
+        Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
+        Expression condition = ExpressionUtils.and(pushSide, reserveSide);
+
+        Plan left = rStudent;
+        Plan right = rScore;
+        if (testRight) {
+            left = rScore;
+            right = rStudent;
+        }
+
+        Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right);
+        Plan root = new LogicalProject<>(Lists.newArrayList(), join);
+
+        Group rootGroup = rewrite(root).getRoot();
+
+        Plan shouldJoin = joinPlan(rootGroup);
+        Plan shouldFilter = joinChild(rootGroup, testRight ? 1 : 0);
+        Plan shouldScan = joinChild(rootGroup, testRight ? 0 : 1);
+
+        Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
+        Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
+        Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
+        LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
+        Assertions.assertEquals(pushSide, actualFilter.getPredicates());
+        // The conjunct on the preserved side must be kept on the join, not dropped.
+        Assertions.assertEquals(Optional.of(reserveSide),
+                ((LogicalJoin<?, ?>) shouldJoin).getOtherJoinCondition());
+    }
+
+    private Memo rewrite(Plan plan) {
+        return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushDownJoinOtherCondition());
+    }
+
+    /** Returns the join plan directly below the root project after rewriting. */
+    private Plan joinPlan(Group rootGroup) {
+        return rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
+    }
+
+    /** Returns the plan of the join's {@code index}-th child (0 = left, 1 = right). */
+    private Plan joinChild(Group rootGroup, int index) {
+        return rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+                .child(index).getLogicalExpression().getPlan();
+    }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java
deleted file mode 100644
index 3224045c2c..0000000000
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java
+++ /dev/null
@@ -1,228 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite.logical;
-
-import org.apache.doris.nereids.memo.Group;
-import org.apache.doris.nereids.memo.Memo;
-import
org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
-import org.apache.doris.nereids.trees.expressions.Add;
-import org.apache.doris.nereids.trees.expressions.And;
-import org.apache.doris.nereids.trees.expressions.Between;
-import org.apache.doris.nereids.trees.expressions.Cast;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.GreaterThan;
-import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
-import org.apache.doris.nereids.trees.expressions.LessThanEqual;
-import org.apache.doris.nereids.trees.expressions.Subtract;
-import org.apache.doris.nereids.trees.expressions.literal.Literal;
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.types.DoubleType;
-import org.apache.doris.nereids.types.StringType;
-import org.apache.doris.nereids.util.ExpressionUtils;
-import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.nereids.util.PlanRewriter;
-import org.apache.doris.qe.ConnectContext;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.TestInstance;
-
-import java.util.ArrayList;
-import java.util.Optional;
-
-/**
- * plan rewrite ut.
- */
-@TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class PushDownPredicateTest {
-
- private Plan rStudent;
- private Plan rScore;
- private Plan rCourse;
-
- /**
- * ut before.
- */
- @BeforeAll
- public final void beforeAll() {
- rStudent = new LogicalOlapScan(PlanConstructor.getNextId(),
PlanConstructor.student, ImmutableList.of(""));
-
- rScore = new LogicalOlapScan(PlanConstructor.getNextId(),
PlanConstructor.score, ImmutableList.of(""));
-
- rCourse = new LogicalOlapScan(PlanConstructor.getNextId(),
PlanConstructor.course, ImmutableList.of(""));
- }
-
- @Test
- public void pushDownPredicateIntoScanTest1() {
- // select id,name,grade from student join score on student.id =
score.sid and student.id > 1
- // and score.cid > 2 where student.age > 18 and score.grade > 60
- Expression onCondition1 = new EqualTo(rStudent.getOutput().get(0),
rScore.getOutput().get(0));
- Expression onCondition2 = new GreaterThan(rStudent.getOutput().get(0),
Literal.of(1));
- Expression onCondition3 = new GreaterThan(rScore.getOutput().get(0),
Literal.of(2));
- Expression onCondition = ExpressionUtils.and(onCondition1,
onCondition2, onCondition3);
-
- Expression whereCondition1 = new
GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
- Expression whereCondition2 = new
GreaterThan(rScore.getOutput().get(2), Literal.of(60));
- Expression whereCondition = ExpressionUtils.and(whereCondition1,
whereCondition2);
-
- Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(),
Optional.of(onCondition), rStudent, rScore);
- Plan filter = new LogicalFilter(whereCondition, join);
-
- Plan root = new LogicalProject(
- Lists.newArrayList(rStudent.getOutput().get(1),
rCourse.getOutput().get(1), rScore.getOutput().get(2)),
- filter
- );
-
- Memo memo = rewrite(root);
-
- Group rootGroup = memo.getRoot();
-
- Plan op1 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
- Plan op2 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
- .getPlan();
- Plan op3 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression()
- .getPlan();
-
- Assertions.assertTrue(op1 instanceof LogicalJoin);
- Assertions.assertTrue(op2 instanceof LogicalFilter);
- Assertions.assertTrue(op3 instanceof LogicalFilter);
- LogicalJoin join1 = (LogicalJoin) op1;
- LogicalFilter filter1 = (LogicalFilter) op2;
- LogicalFilter filter2 = (LogicalFilter) op3;
-
- Assertions.assertEquals(onCondition1,
join1.getOtherJoinCondition().get());
- Assertions.assertEquals(ExpressionUtils.and(onCondition2,
whereCondition1), filter1.getPredicates());
- Assertions.assertEquals(ExpressionUtils.and(onCondition3,
- new GreaterThan(rScore.getOutput().get(2), new
Cast(Literal.of(60), DoubleType.INSTANCE))),
- filter2.getPredicates());
- }
-
- @Test
- public void pushDownPredicateIntoScanTest3() {
- //select id,name,grade from student left join score on student.id + 1
= score.sid - 2
- //where student.age > 18 and score.grade > 60
- Expression whereCondition1 = new EqualTo(new
Add(rStudent.getOutput().get(0), Literal.of(1)),
- new Subtract(rScore.getOutput().get(0), Literal.of(2)));
- Expression whereCondition2 = new
GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
- Expression whereCondition3 = new
GreaterThan(rScore.getOutput().get(2), Literal.of(60));
- Expression whereCondition = ExpressionUtils.and(whereCondition1,
whereCondition2, whereCondition3);
-
- Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(),
Optional.empty(), rStudent, rScore);
- Plan filter = new LogicalFilter(whereCondition, join);
-
- Plan root = new LogicalProject(
- Lists.newArrayList(rStudent.getOutput().get(1),
rCourse.getOutput().get(1), rScore.getOutput().get(2)),
- filter
- );
-
- Memo memo = rewrite(root);
- Group rootGroup = memo.getRoot();
-
- Plan op1 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
- Plan op2 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
- .getPlan();
- Plan op3 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression()
- .getPlan();
-
- Assertions.assertTrue(op1 instanceof LogicalJoin);
- Assertions.assertTrue(op2 instanceof LogicalFilter);
- Assertions.assertTrue(op3 instanceof LogicalFilter);
- LogicalJoin join1 = (LogicalJoin) op1;
- LogicalFilter filter1 = (LogicalFilter) op2;
- LogicalFilter filter2 = (LogicalFilter) op3;
- Assertions.assertEquals(whereCondition1,
join1.getOtherJoinCondition().get());
- Assertions.assertEquals(whereCondition2, filter1.getPredicates());
- Assertions.assertEquals(
- new GreaterThan(rScore.getOutput().get(2), new
Cast(Literal.of(60), DoubleType.INSTANCE)),
- filter2.getPredicates());
- }
-
- @Test
- public void pushDownPredicateIntoScanTest4() {
- /*
- select
- student.name,
- course.name,
- score.grade
- from student,score,course
- where on student.id = score.sid and student.age between 18 and 20 and
score.grade > 60 and student.id = score.sid
- */
-
- // student.id = score.sid
- Expression whereCondition1 = new EqualTo(rStudent.getOutput().get(0),
rScore.getOutput().get(0));
- // score.cid = course.cid
- Expression whereCondition2 = new EqualTo(rScore.getOutput().get(1),
rCourse.getOutput().get(0));
- // student.age between 18 and 20
- Expression whereCondition3 = new Between(rStudent.getOutput().get(2),
Literal.of(18), Literal.of(20));
- // student.age >= 18 and student.age <= 20
- Expression whereCondition3result = new And(
- new GreaterThanEqual(rStudent.getOutput().get(2), new
Cast(Literal.of(18), StringType.INSTANCE)),
- new LessThanEqual(rStudent.getOutput().get(2), new
Cast(Literal.of(20), StringType.INSTANCE)));
-
- // score.grade > 60
- Expression whereCondition4 = new
GreaterThan(rScore.getOutput().get(2), Literal.of(60));
-
- Expression whereCondition = ExpressionUtils.and(whereCondition1,
whereCondition2, whereCondition3,
- whereCondition4);
-
- Plan join = new LogicalJoin(JoinType.INNER_JOIN, ImmutableList.of(),
Optional.empty(), rStudent, rScore);
- Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, ImmutableList.of(),
Optional.empty(), join, rCourse);
- Plan filter = new LogicalFilter(whereCondition, join1);
-
- Plan root = new LogicalProject(
- Lists.newArrayList(rStudent.getOutput().get(1),
rCourse.getOutput().get(1), rScore.getOutput().get(2)),
- filter
- );
-
- Memo memo = rewrite(root);
- Group rootGroup = memo.getRoot();
- Plan join2 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
- Plan join3 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
- .getPlan();
- Plan op1 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
- .child(0).getLogicalExpression().getPlan();
- Plan op2 =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
- .child(1).getLogicalExpression().getPlan();
-
- Assertions.assertTrue(join2 instanceof LogicalJoin);
- Assertions.assertTrue(join3 instanceof LogicalJoin);
- Assertions.assertTrue(op1 instanceof LogicalFilter);
- Assertions.assertTrue(op2 instanceof LogicalFilter);
-
- Assertions.assertEquals(whereCondition2, ((LogicalJoin)
join2).getOtherJoinCondition().get());
- Assertions.assertEquals(whereCondition1, ((LogicalJoin)
join3).getOtherJoinCondition().get());
- Assertions.assertEquals(whereCondition3result.toSql(),
((LogicalFilter) op1).getPredicates().toSql());
- Assertions.assertEquals(
- new GreaterThan(rScore.getOutput().get(2), new
Cast(Literal.of(60), DoubleType.INSTANCE)),
- ((LogicalFilter) op2).getPredicates());
- }
-
- private Memo rewrite(Plan plan) {
- Plan normalizedPlan = PlanRewriter.topDownRewrite(plan, new
ConnectContext(), new ExpressionNormalization());
- return PlanRewriter.topDownRewriteMemo(normalizedPlan, new
ConnectContext(), new PushPredicateThroughJoin());
- }
-}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java
new file mode 100644
index 0000000000..3374613f73
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java
@@ -0,0 +1,208 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.memo.Memo;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.nereids.util.PlanRewriter;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+
+import java.util.Optional;
+
+/**
+ * plan rewrite ut.
+ */
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public class PushPredicateThroughJoinTest {
+
+ private Plan rStudent;
+ private Plan rScore;
+
+ /**
+ * ut before.
+ */
+ @BeforeAll
+ public final void beforeAll() {
+ rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.student, ImmutableList.of(""));
+ rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.score, ImmutableList.of(""));
+ }
+
+ @Test
+ public void oneSide() {
+ oneSide(JoinType.CROSS_JOIN, false);
+ oneSide(JoinType.INNER_JOIN, false);
+ oneSide(JoinType.LEFT_OUTER_JOIN, false);
+ oneSide(JoinType.LEFT_SEMI_JOIN, false);
+ oneSide(JoinType.LEFT_ANTI_JOIN, false);
+ oneSide(JoinType.RIGHT_OUTER_JOIN, true);
+ oneSide(JoinType.RIGHT_SEMI_JOIN, true);
+ oneSide(JoinType.RIGHT_ANTI_JOIN, true);
+ }
+
+ private void oneSide(JoinType joinType, boolean testRight) {
+
+ Expression whereCondition1 = new
GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
+ Expression whereCondition2 = new
GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
+ Expression whereCondition = ExpressionUtils.and(whereCondition1,
whereCondition2);
+
+ Plan left = rStudent;
+ Plan right = rScore;
+ if (testRight) {
+ left = rScore;
+ right = rStudent;
+ }
+
+ Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(),
Optional.empty(), left, right);
+ Plan filter = new LogicalFilter<>(whereCondition, join);
+ Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
+
+ Memo memo = rewrite(root);
+ Group rootGroup = memo.getRoot();
+
+ Plan shouldJoin =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
+ Plan shouldFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+ .child(0).getLogicalExpression().getPlan();
+ Plan shouldScan =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+ .child(1).getLogicalExpression().getPlan();
+ if (testRight) {
+ shouldFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+ .child(1).getLogicalExpression().getPlan();
+ shouldScan =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+ .child(0).getLogicalExpression().getPlan();
+ }
+
+ Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
+ Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
+ Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
+ LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
+
+ Assertions.assertEquals(whereCondition, actualFilter.getPredicates());
+ }
+
+ @Test
+ public void bothSideToBothSide() {
+ bothSideToBothSide(JoinType.INNER_JOIN);
+ }
+
+ private void bothSideToBothSide(JoinType joinType) {
+
+ Expression bothSideEqualTo = new EqualTo(new
Add(rStudent.getOutput().get(0), Literal.of(1)),
+ new Subtract(rScore.getOutput().get(0), Literal.of(2)));
+ Expression leftSide = new GreaterThan(rStudent.getOutput().get(1),
Literal.of(18));
+ Expression rightSide = new GreaterThan(rScore.getOutput().get(2),
Literal.of(60));
+ Expression whereCondition = ExpressionUtils.and(bothSideEqualTo,
leftSide, rightSide);
+
+ Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(),
Optional.empty(), rStudent, rScore);
+ Plan filter = new LogicalFilter<>(whereCondition, join);
+ Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
+
+ Memo memo = rewrite(root);
+ Group rootGroup = memo.getRoot();
+
+ Plan shouldJoin =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
+ Plan leftFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+ .child(0).getLogicalExpression().getPlan();
+ Plan rightFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+ .child(1).getLogicalExpression().getPlan();
+
+ Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
+ Assertions.assertTrue(leftFilter instanceof LogicalFilter);
+ Assertions.assertTrue(rightFilter instanceof LogicalFilter);
+ LogicalJoin<Plan, Plan> actualJoin = (LogicalJoin<Plan, Plan>)
shouldJoin;
+ LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
+ LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
+ Assertions.assertEquals(bothSideEqualTo,
actualJoin.getOtherJoinCondition().get());
+ Assertions.assertEquals(leftSide, actualLeft.getPredicates());
+ Assertions.assertEquals(rightSide, actualRight.getPredicates());
+ }
+
+ @Test
+ public void bothSideToOneSide() {
+ bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, false);
+ bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, false);
+ bothSideToOneSide(JoinType.LEFT_SEMI_JOIN, false);
+ bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, true);
+ bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, true);
+ bothSideToOneSide(JoinType.RIGHT_SEMI_JOIN, true);
+ }
+
+ private void bothSideToOneSide(JoinType joinType, boolean testRight) {
+
+ Expression pushSide = new GreaterThan(rStudent.getOutput().get(1),
Literal.of(18));
+ Expression reserveSide = new GreaterThan(rScore.getOutput().get(2),
Literal.of(60));
+ Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide);
+
+ Plan left = rStudent;
+ Plan right = rScore;
+ if (testRight) {
+ left = rScore;
+ right = rStudent;
+ }
+
+ Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(),
Optional.empty(), left, right);
+ Plan filter = new LogicalFilter<>(whereCondition, join);
+ Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
+
+ Memo memo = rewrite(root);
+ Group rootGroup = memo.getRoot();
+
+ Plan shouldJoin =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+ .child(0).getLogicalExpression().getPlan();
+ Plan shouldFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
+ Plan shouldScan =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
+ if (testRight) {
+ shouldFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
+ shouldScan =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
+
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
+ }
+
+ Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
+ Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
+ Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
+ LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
+ Assertions.assertEquals(pushSide, actualFilter.getPredicates());
+ }
+
+ private Memo rewrite(Plan plan) {
+ return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new
PushPredicatesThroughJoin());
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java
index 135aaaf419..c03d24445b 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java
@@ -37,7 +37,7 @@ public class PlanConstructor {
public static OlapTable student;
public static OlapTable score;
public static OlapTable course;
- private static final IdGenerator<RelationId> GENERATOR =
RelationId.createGenerator();
+    // Class-wide source of relation ids, used by newLogicalOlapScan() and getNextRelationId().
+ private static final IdGenerator<RelationId> RELATION_ID_GENERATOR =
RelationId.createGenerator();
static {
student = new OlapTable(0L, "student",
@@ -102,14 +102,14 @@ public class PlanConstructor {
// With OlapTable.
// Warning: equals() of Table depends on tableId.
public static LogicalOlapScan newLogicalOlapScan(long tableId, String
tableName, int hashColumn) {
- return new LogicalOlapScan(GENERATOR.getNextId(),
newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
+        // Takes the next id from the class-wide RELATION_ID_GENERATOR (shared with getNextRelationId()).
+ return new LogicalOlapScan(RELATION_ID_GENERATOR.getNextId(),
newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
}
+    // Creates a brand-new generator on every call and uses its first id, so scans built here are
+    // presumably all given the same relation id, as the method name indicates.
public static LogicalOlapScan newLogicalOlapScanWithSameId(long tableId,
String tableName, int hashColumn) {
return new LogicalOlapScan(RelationId.createGenerator().getNextId(),
newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
}
- public static RelationId getNextId() {
- return GENERATOR.getNextId();
+    /**
+     * Returns the next id from the class-wide RELATION_ID_GENERATOR.
+     */
+ public static RelationId getNextRelationId() {
+ return RELATION_ID_GENERATOR.getNextId();
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]