This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 11fbe07221 [refactor](Nereids) Refactor all rewrite logical unit tests by match-pattern (#17691)
11fbe07221 is described below
commit 11fbe072213e1bb007b2e9d1b97217e22045df27
Author: Weijie Guo <[email protected]>
AuthorDate: Sun Mar 12 18:49:12 2023 +0800
[refactor](Nereids) Refactor all rewrite logical unit tests by match-pattern (#17691)
---
.../logical/EliminateGroupByConstantTest.java | 196 ++++++++++-----------
.../logical/EliminateUnnecessaryProjectTest.java | 49 ++----
.../logical/FindHashConditionForJoinTest.java | 32 ++--
.../rules/rewrite/logical/MergeFiltersTest.java | 47 +++--
.../rules/rewrite/logical/MergeLimitsTest.java | 46 +++--
.../logical/PhysicalStorageLayerAggregateTest.java | 4 +-
.../logical/PruneOlapScanPartitionTest.java | 74 ++++----
.../rewrite/logical/PruneOlapScanTabletTest.java | 20 ++-
.../logical/PushdownJoinOtherConditionTest.java | 170 ++++++++----------
.../logical/PushdownProjectThroughLimitTest.java | 4 +-
.../rules/rewrite/logical/SplitLimitTest.java | 25 ++-
.../org/apache/doris/nereids/util/PlanChecker.java | 7 +
12 files changed, 323 insertions(+), 351 deletions(-)
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java
index 9aa4a027a8..4af6f2234b 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java
@@ -23,30 +23,29 @@ import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.plans.ObjectId;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.thrift.TStorageType;
import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-import java.util.List;
-
-public class EliminateGroupByConstantTest {
+/** Tests for {@link EliminateGroupByConstant}. */
+class EliminateGroupByConstantTest implements MemoPatternMatchSupported {
private static final OlapTable table = new OlapTable(0L, "student",
ImmutableList.of(new Column("k1", Type.INT, true,
AggregateType.NONE, "0", ""),
new Column("k2", Type.INT, false, AggregateType.NONE, "0",
""),
@@ -65,110 +64,107 @@ public class EliminateGroupByConstantTest {
}
@Test
- public void testIntegerLiteral() {
- LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
- ImmutableList.of(new IntegerLiteral(1), k2),
- ImmutableList.of(k1, k2),
- new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table)
- );
-
- CascadesContext context =
MemoTestUtils.createCascadesContext(aggregate);
- context.topDownRewrite(new EliminateGroupByConstant().build());
- context.bottomUpRewrite(new CheckAfterRewrite().build());
-
- LogicalAggregate aggregate1 = ((LogicalAggregate)
context.getMemo().copyOut());
- Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
- Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0)
instanceof Slot);
+ void testIntegerLiteral() {
+ LogicalPlan aggregate = new LogicalPlanBuilder(
+ new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table))
+ .agg(ImmutableList.of(new IntegerLiteral(1), k2),
+ ImmutableList.of(k1, k2))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+ .applyTopDown(new EliminateGroupByConstant())
+ .applyBottomUp(new CheckAfterRewrite())
+ .matches(
+ aggregate().when(agg ->
agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
+ );
}
@Test
- public void testOtherLiteral() {
- LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
- ImmutableList.of(
- new StringLiteral("str"), k2),
- ImmutableList.of(
- new Alias(new StringLiteral("str"), "str"), k1, k2),
- new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table)
- );
-
- CascadesContext context =
MemoTestUtils.createCascadesContext(aggregate);
- context.topDownRewrite(new EliminateGroupByConstant().build());
- context.bottomUpRewrite(new CheckAfterRewrite().build());
-
- LogicalAggregate aggregate1 = ((LogicalAggregate)
context.getMemo().copyOut());
- Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
- Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0)
instanceof Slot);
+ void testOtherLiteral() {
+ LogicalPlan aggregate = new LogicalPlanBuilder(
+ new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table))
+ .agg(ImmutableList.of(
+ new StringLiteral("str"), k2),
+ ImmutableList.of(
+ new Alias(new StringLiteral("str"), "str"), k1,
k2))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+ .applyTopDown(new EliminateGroupByConstant())
+ .applyBottomUp(new CheckAfterRewrite())
+ .matches(
+ aggregate().when(agg ->
agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
+ );
}
@Test
- public void testMixedLiteral() {
- LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
- ImmutableList.of(
- new StringLiteral("str"), k2,
- new IntegerLiteral(1),
- new IntegerLiteral(2),
- new IntegerLiteral(3),
- new Add(k1, k2)),
- ImmutableList.of(
- new Alias(new StringLiteral("str"), "str"),
- k2, k1, new Alias(new IntegerLiteral(1), "integer")),
- new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table)
- );
-
- CascadesContext context =
MemoTestUtils.createCascadesContext(aggregate);
- context.topDownRewrite(new EliminateGroupByConstant().build());
- context.bottomUpRewrite(new CheckAfterRewrite().build());
-
- LogicalAggregate aggregate1 = ((LogicalAggregate)
context.getMemo().copyOut());
- Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2);
- List groupByExprs = aggregate1.getGroupByExpressions();
- Assertions.assertTrue(groupByExprs.get(0) instanceof Slot
- && groupByExprs.get(1) instanceof Add);
+ void testMixedLiteral() {
+ LogicalPlan aggregate = new LogicalPlanBuilder(
+ new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table))
+ .agg(ImmutableList.of(
+ new StringLiteral("str"), k2,
+ new IntegerLiteral(1),
+ new IntegerLiteral(2),
+ new IntegerLiteral(3),
+ new Add(k1, k2)),
+ ImmutableList.of(
+ new Alias(new StringLiteral("str"), "str"),
+ k2, k1, new Alias(new IntegerLiteral(1),
"integer")))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+ .applyTopDown(new EliminateGroupByConstant())
+ .applyBottomUp(new CheckAfterRewrite())
+ .matches(
+ aggregate()
+ .when(agg ->
agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
+ );
}
@Test
- public void testComplexGroupBy() {
- LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
- ImmutableList.of(
- new IntegerLiteral(1),
- new IntegerLiteral(2),
- new Add(k1, k2)),
- ImmutableList.of(
- new Alias(new Max(k1), "max"),
- new Alias(new Min(k2), "min"),
- new Alias(new Add(k1, k2), "add")),
- new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table)
- );
-
- CascadesContext context =
MemoTestUtils.createCascadesContext(aggregate);
- context.topDownRewrite(new EliminateGroupByConstant().build());
- context.bottomUpRewrite(new CheckAfterRewrite().build());
-
- LogicalAggregate aggregate1 = ((LogicalAggregate)
context.getMemo().copyOut());
- Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
+ void testComplexGroupBy() {
+ LogicalPlan aggregate = new LogicalPlanBuilder(
+ new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table))
+ .agg(ImmutableList.of(
+ new IntegerLiteral(1),
+ new IntegerLiteral(2),
+ new Add(k1, k2)),
+ ImmutableList.of(
+ new Alias(new Max(k1), "max"),
+ new Alias(new Min(k2), "min"),
+ new Alias(new Add(k1, k2), "add")))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+ .applyTopDown(new EliminateGroupByConstant())
+ .applyBottomUp(new CheckAfterRewrite())
+ .matches(
+ aggregate()
+ .when(agg ->
agg.getGroupByExpressions().equals(ImmutableList.of(new Add(k1, k2))))
+ );
}
@Test
- public void testOutOfRange() {
- LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
- ImmutableList.of(
- new StringLiteral("str"), k2,
- new IntegerLiteral(1),
- new IntegerLiteral(2),
- new IntegerLiteral(3),
- new IntegerLiteral(5),
- new Add(k1, k2)),
- ImmutableList.of(
- new Alias(new StringLiteral("str"), "str"),
- k2, k1, new Alias(new IntegerLiteral(1), "integer")),
- new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table)
- );
-
- CascadesContext context =
MemoTestUtils.createCascadesContext(aggregate);
- context.topDownRewrite(new EliminateGroupByConstant().build());
- context.bottomUpRewrite(new CheckAfterRewrite().build());
-
- LogicalAggregate aggregate1 = ((LogicalAggregate)
context.getMemo().copyOut());
- Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2);
+ void testOutOfRange() {
+ LogicalPlan aggregate = new LogicalPlanBuilder(
+ new LogicalOlapScan(ObjectId.createGenerator().getNextId(),
table))
+ .agg(ImmutableList.of(
+ new StringLiteral("str"), k2,
+ new IntegerLiteral(1),
+ new IntegerLiteral(2),
+ new IntegerLiteral(3),
+ new IntegerLiteral(5),
+ new Add(k1, k2)),
+ ImmutableList.of(
+ new Alias(new StringLiteral("str"),
"str"),
+ k2, k1, new Alias(new IntegerLiteral(1),
"integer")))
+ .build();
+ PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+ .applyTopDown(new EliminateGroupByConstant())
+ .applyBottomUp(new CheckAfterRewrite())
+ .matches(
+ aggregate()
+ .when(agg ->
agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
+ );
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
index ecd6de1cd5..4435eb16f8 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
@@ -17,28 +17,25 @@
package org.apache.doris.nereids.rules.rewrite.logical;
-import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
-import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
* test ELIMINATE_UNNECESSARY_PROJECT rule.
*/
-public class EliminateUnnecessaryProjectTest extends TestWithFeService {
+class EliminateUnnecessaryProjectTest extends TestWithFeService implements
MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
@@ -55,57 +52,49 @@ public class EliminateUnnecessaryProjectTest extends
TestWithFeService {
}
@Test
- public void testEliminateNonTopUnnecessaryProject() {
+ void testEliminateNonTopUnnecessaryProject() {
LogicalPlan unnecessaryProject = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
.project(ImmutableList.of(1, 0))
.filter(BooleanLiteral.FALSE)
.build();
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(unnecessaryProject);
- cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
-
- Plan actual = cascadesContext.getMemo().copyOut();
- Assertions.assertTrue(actual.child(0) instanceof LogicalProject);
+ PlanChecker.from(MemoTestUtils.createConnectContext(),
unnecessaryProject)
+ .applyTopDown(new EliminateUnnecessaryProject())
+ .matchesFromRoot(logicalFilter(logicalProject()));
}
@Test
- public void testEliminateTopUnnecessaryProject() {
+ void testEliminateTopUnnecessaryProject() {
LogicalPlan unnecessaryProject = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
.project(ImmutableList.of(0, 1))
.build();
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(unnecessaryProject);
- cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
-
- Plan actual = cascadesContext.getMemo().copyOut();
- Assertions.assertTrue(actual instanceof LogicalOlapScan);
+ PlanChecker.from(MemoTestUtils.createConnectContext(),
unnecessaryProject)
+ .applyTopDown(new EliminateUnnecessaryProject())
+ .matchesFromRoot(logicalOlapScan());
}
@Test
- public void testNotEliminateTopProjectWhenOutputNotEquals() {
+ void testNotEliminateTopProjectWhenOutputNotEquals() {
LogicalPlan necessaryProject = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
.project(ImmutableList.of(1, 0))
.build();
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(necessaryProject);
- cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
-
- Plan actual = cascadesContext.getMemo().copyOut();
- Assertions.assertTrue(actual instanceof LogicalProject);
+ PlanChecker.from(MemoTestUtils.createConnectContext(),
necessaryProject)
+ .applyTopDown(new EliminateUnnecessaryProject())
+ .matchesFromRoot(logicalProject());
}
@Test
- public void testEliminateProjectWhenEmptyRelationChild() {
+ void testEliminateProjectWhenEmptyRelationChild() {
LogicalPlan unnecessaryProject = new LogicalPlanBuilder(new
LogicalEmptyRelation(ImmutableList.of(
new SlotReference("k1", IntegerType.INSTANCE),
new SlotReference("k2", IntegerType.INSTANCE))))
.project(ImmutableList.of(1, 0))
.build();
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(unnecessaryProject);
- cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
-
- Plan actual = cascadesContext.getMemo().copyOut();
- Assertions.assertTrue(actual instanceof LogicalEmptyRelation);
+ PlanChecker.from(MemoTestUtils.createConnectContext(),
unnecessaryProject)
+ .applyTopDown(new EliminateUnnecessaryProject())
+ .matchesFromRoot(logicalEmptyRelation());
}
// TODO: uncomment this after the Elimination project rule is correctly
implemented
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
index 020e1d79d6..8bf0bada3f 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
@@ -17,8 +17,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
-import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
@@ -31,12 +29,12 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@@ -54,9 +52,9 @@ import java.util.Optional;
* -hashJoinConjuncts={A.x=B.x, A.y+1=B.y}
* -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x"
*/
-class FindHashConditionForJoinTest {
+class FindHashConditionForJoinTest implements MemoPatternMatchSupported {
@Test
- public void testFindHashCondition() {
+ void testFindHashCondition() {
Plan student = new
LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
ImmutableList.of(""));
Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.score,
@@ -77,20 +75,12 @@ class FindHashConditionForJoinTest {
List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less);
LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new
ArrayList<>(),
expr, JoinHint.NONE, Optional.empty(), student, score);
- CascadesContext context = MemoTestUtils.createCascadesContext(join);
- List<Rule> rules = Lists.newArrayList(new
FindHashConditionForJoin().build());
- context.topDownRewrite(rules);
- Plan plan = context.getMemo().copyOut();
- LogicalJoin after = (LogicalJoin) plan;
- Assertions.assertEquals(after.getHashJoinConjuncts().size(), 2);
- Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq1));
- Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq3));
- List<Expression> others = after.getOtherJoinConjuncts();
- Assertions.assertEquals(others.size(), 3);
- Assertions.assertTrue(others.contains(less));
- Assertions.assertTrue(others.contains(eq2));
- Assertions.assertTrue(others.contains(less));
+ PlanChecker.from(new ConnectContext(), join)
+ .applyTopDown(new FindHashConditionForJoin())
+ .matches(
+ logicalJoin()
+ .when(j ->
j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3)))
+ .when(j ->
j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less))));
}
-
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java
index 31bf853bdf..e706151e0e 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java
@@ -17,47 +17,42 @@
package org.apache.doris.nereids.rules.rewrite.logical;
-import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundRelation;
-import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
-import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-import java.util.List;
-import java.util.Set;
-
/**
- * MergeConsecutiveFilter ut
+ * Tests for {@link MergeFilters}.
*/
-public class MergeFiltersTest {
+class MergeFiltersTest implements MemoPatternMatchSupported {
@Test
- public void testMergeConsecutiveFilters() {
- UnboundRelation relation = new
UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db",
"table"));
+ void testMergeFilters() {
Expression expression1 = new IntegerLiteral(1);
- LogicalFilter filter1 = new
LogicalFilter<>(ImmutableSet.of(expression1), relation);
Expression expression2 = new IntegerLiteral(2);
- LogicalFilter filter2 = new
LogicalFilter<>(ImmutableSet.of(expression2), filter1);
Expression expression3 = new IntegerLiteral(3);
- LogicalFilter filter3 = new
LogicalFilter<>(ImmutableSet.of(expression3), filter2);
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(filter3);
- List<Rule> rules = Lists.newArrayList(new MergeFilters().build());
- cascadesContext.bottomUpRewrite(rules);
- //check transformed plan
- Plan resultPlan = cascadesContext.getMemo().copyOut();
- System.out.println(resultPlan.treeString());
- Assertions.assertTrue(resultPlan instanceof LogicalFilter);
- Set<Expression> allPredicates = ImmutableSet.of(expression1,
expression2, expression3);
- Assertions.assertEquals(ImmutableSet.copyOf(((LogicalFilter<?>)
resultPlan).getConjuncts()), allPredicates);
- Assertions.assertTrue(resultPlan.child(0) instanceof UnboundRelation);
+ LogicalPlan logicalFilter = new LogicalPlanBuilder(
+ new UnboundRelation(RelationUtil.newRelationId(),
Lists.newArrayList("db", "table")))
+ .filter(ImmutableSet.of(expression1))
+ .filter(ImmutableSet.of(expression2))
+ .filter(ImmutableSet.of(expression3))
+ .build();
+
+ PlanChecker.from(new ConnectContext(),
logicalFilter).applyBottomUp(new MergeFilters())
+ .matches(
+ logicalFilter(
+ unboundRelation()
+ ).when(filter -> filter.getConjuncts()
+ .equals(ImmutableSet.of(expression1,
expression2, expression3))));
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java
index 869dec982f..ae608f1d26 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java
@@ -17,37 +17,33 @@
package org.apache.doris.nereids.rules.rewrite.logical;
-import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundRelation;
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.trees.plans.LimitPhase;
-import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
-import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-import java.util.List;
-
-public class MergeLimitsTest {
+/**
+ * Tests for {@link MergeLimits}.
+ */
+class MergeLimitsTest implements MemoPatternMatchSupported {
@Test
- public void testMergeConsecutiveLimits() {
- LogicalLimit limit3 = new LogicalLimit<>(3, 5, LimitPhase.ORIGIN, new
UnboundRelation(
- RelationUtil.newRelationId(), Lists.newArrayList("db", "t")));
- LogicalLimit limit2 = new LogicalLimit<>(2, 0, LimitPhase.ORIGIN,
limit3);
- LogicalLimit limit1 = new LogicalLimit<>(10, 0, LimitPhase.ORIGIN,
limit2);
-
- CascadesContext context = MemoTestUtils.createCascadesContext(limit1);
- List<Rule> rules = Lists.newArrayList(new MergeLimits().build());
- context.topDownRewrite(rules);
- LogicalLimit limit = (LogicalLimit) context.getMemo().copyOut();
-
- Assertions.assertEquals(2, limit.getLimit());
- Assertions.assertEquals(5, limit.getOffset());
- Assertions.assertEquals(1, limit.children().size());
- Assertions.assertTrue(limit.child(0) instanceof UnboundRelation);
-
+ void testMergeLimits() {
+ LogicalPlan logicalLimit = new LogicalPlanBuilder(
+ new UnboundRelation(RelationUtil.newRelationId(),
Lists.newArrayList("db", "t")))
+ .limit(3, 5)
+ .limit(2, 0)
+ .limit(10, 0).build();
+
+ PlanChecker.from(new ConnectContext(), logicalLimit).applyTopDown(new
MergeLimits())
+ .matches(
+ logicalLimit(
+ unboundRelation()
+ ).when(limit -> limit.getLimit() == 2).when(limit ->
limit.getOffset() == 5));
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java
index 1556438652..69e07eebc7 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java
@@ -18,7 +18,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.pattern.GeneratedMemoPatterns;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.rules.RuleType;
@@ -32,6 +31,7 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import
org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
@@ -42,7 +42,7 @@ import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.Optional;
-public class PhysicalStorageLayerAggregateTest implements
GeneratedMemoPatterns {
+public class PhysicalStorageLayerAggregateTest implements
MemoPatternMatchSupported {
@Test
public void testWithoutProject() {
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
index a22d2b29f5..73156a6cfd 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
@@ -26,8 +26,6 @@ import org.apache.doris.catalog.RangePartitionInfo;
import org.apache.doris.catalog.RangePartitionItem;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.jmockit.Deencapsulation;
-import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
@@ -36,20 +34,19 @@ import
org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
-import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Lists;
import com.google.common.collect.Range;
import mockit.Expectations;
import mockit.Mocked;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@@ -58,10 +55,10 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
-class PruneOlapScanPartitionTest {
+class PruneOlapScanPartitionTest implements MemoPatternMatchSupported {
@Test
- public void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable
olapTable) throws Exception {
+ void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable
olapTable) throws Exception {
List<Column> columnNameList = new ArrayList<>();
columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
@@ -92,24 +89,29 @@ class PruneOlapScanPartitionTest {
Expression expression = new LessThan(slotRef, new IntegerLiteral(4));
LogicalFilter<LogicalOlapScan> filter = new
LogicalFilter<>(ImmutableSet.of(expression), scan);
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(filter);
- List<Rule> rules = Lists.newArrayList(new
PruneOlapScanPartition().build());
- cascadesContext.topDownRewrite(rules);
- Plan resultPlan = cascadesContext.getMemo().copyOut();
- LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan)
resultPlan.child(0);
- Assertions.assertEquals(0L,
rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
+ PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+ .applyTopDown(new PruneOlapScanPartition())
+ .matches(
+ logicalFilter(
+ logicalOlapScan().when(
+ olapScan ->
olapScan.getSelectedPartitionIds().iterator().next() == 0L)
+ )
+ );
Expression lessThan0 = new LessThan(slotRef, new IntegerLiteral(0));
Expression greaterThan6 = new GreaterThan(slotRef, new
IntegerLiteral(6));
Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6);
filter = new LogicalFilter<>(ImmutableSet.of(lessThan0OrGreaterThan6),
scan);
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
olapTable);
- cascadesContext = MemoTestUtils.createCascadesContext(filter);
- rules = Lists.newArrayList(new PruneOlapScanPartition().build());
- cascadesContext.topDownRewrite(rules);
- resultPlan = cascadesContext.getMemo().copyOut();
- rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
- Assertions.assertEquals(1L,
rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+ .applyTopDown(new PruneOlapScanPartition())
+ .matches(
+ logicalFilter(
+ logicalOlapScan().when(
+ olapScan ->
olapScan.getSelectedPartitionIds().iterator().next() == 1L)
+ )
+ );
Expression greaterThanEqual0 =
new GreaterThanEqual(
@@ -118,17 +120,20 @@ class PruneOlapScanPartitionTest {
new LessThanEqual(slotRef, new IntegerLiteral(5));
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
olapTable);
filter = new LogicalFilter<>(ImmutableSet.of(greaterThanEqual0,
lessThanEqual5), scan);
- cascadesContext = MemoTestUtils.createCascadesContext(filter);
- rules = Lists.newArrayList(new PruneOlapScanPartition().build());
- cascadesContext.topDownRewrite(rules);
- resultPlan = cascadesContext.getMemo().copyOut();
- rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
- Assertions.assertEquals(0L,
rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
- Assertions.assertEquals(2,
rewrittenOlapScan.getSelectedPartitionIds().toArray().length);
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+ .applyTopDown(new PruneOlapScanPartition())
+ .matches(
+ logicalFilter(
+ logicalOlapScan().when(
+ olapScan ->
olapScan.getSelectedPartitionIds().iterator().next() == 0L)
+ .when(olapScan ->
olapScan.getSelectedPartitionIds().size() == 2)
+ )
+ );
}
@Test
- public void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked
OlapTable olapTable) throws Exception {
+ void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable
olapTable) throws Exception {
List<Column> columnNameList = new ArrayList<>();
columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
@@ -155,11 +160,14 @@ class PruneOlapScanPartitionTest {
Expression left = new LessThan(new SlotReference("col1",
IntegerType.INSTANCE), new IntegerLiteral(4));
Expression right = new GreaterThan(new SlotReference("col2",
IntegerType.INSTANCE), new IntegerLiteral(11));
LogicalFilter<LogicalOlapScan> filter = new
LogicalFilter<>(ImmutableSet.of(left, right), scan);
- CascadesContext cascadesContext =
MemoTestUtils.createCascadesContext(filter);
- List<Rule> rules = Lists.newArrayList(new
PruneOlapScanPartition().build());
- cascadesContext.topDownRewrite(rules);
- Plan resultPlan = cascadesContext.getMemo().copyOut();
- LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan)
resultPlan.child(0);
- Assertions.assertEquals(0L,
rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
+ PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+ .applyTopDown(new PruneOlapScanPartition())
+ .matches(
+ logicalFilter(
+ logicalOlapScan()
+ .when(
+ olapScan ->
olapScan.getSelectedPartitionIds().iterator().next() == 0L)
+ )
+ );
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java
index 7ee589a422..5d34aec5db 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java
@@ -29,7 +29,6 @@ import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Partition;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
@@ -41,7 +40,9 @@ import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.planner.PartitionColumnFilter;
import com.google.common.collect.ImmutableList;
@@ -54,10 +55,10 @@ import org.junit.jupiter.api.Test;
import java.util.List;
-public class PruneOlapScanTabletTest {
+class PruneOlapScanTabletTest implements MemoPatternMatchSupported {
@Test
- public void testPruneOlapScanTablet(@Mocked OlapTable olapTable,
+ void testPruneOlapScanTablet(@Mocked OlapTable olapTable,
@Mocked Partition partition, @Mocked MaterializedIndex index,
@Mocked HashDistributionInfo distributionInfo) {
List<Long> tabletIds = Lists.newArrayListWithExpectedSize(300);
@@ -153,11 +154,12 @@ public class PruneOlapScanTabletTest {
Assertions.assertEquals(0,
filter.child().getSelectedTabletIds().size());
- CascadesContext context = MemoTestUtils.createCascadesContext(filter);
- context.topDownRewrite(ImmutableList.of(new
PruneOlapScanTablet().build()));
-
- LogicalFilter<LogicalOlapScan> filter1 =
((LogicalFilter<LogicalOlapScan>) context.getMemo().copyOut());
- LogicalOlapScan olapScan = filter1.child();
- Assertions.assertEquals(19, olapScan.getSelectedTabletIds().size());
+ PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+ .applyTopDown(new PruneOlapScanTablet())
+ .matches(
+ logicalFilter(
+ logicalOlapScan().when(scan ->
scan.getSelectedTabletIds().size() == 19)
+ )
+ );
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java
index 48dbbb7621..b04318b072 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java
@@ -17,52 +17,53 @@
package org.apache.doris.nereids.rules.rewrite.logical;
-import org.apache.doris.nereids.memo.Group;
-import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
-import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.nereids.util.PlanRewriter;
-import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.List;
-import java.util.Optional;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class PushdownJoinOtherConditionTest {
+class PushdownJoinOtherConditionTest implements MemoPatternMatchSupported {
- private Plan rStudent;
- private Plan rScore;
+ private LogicalOlapScan rStudent;
+ private LogicalOlapScan rScore;
+
+ private List<Slot> rStudentSlots;
+
+ private List<Slot> rScoreSlots;
/**
* ut before.
*/
@BeforeAll
- public final void beforeAll() {
+ final void beforeAll() {
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.student,
ImmutableList.of(""));
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.score, ImmutableList.of(""));
+ rStudentSlots = rStudent.getOutput();
+ rScoreSlots = rScore.getOutput();
}
@Test
- public void oneSide() {
+ void oneSide() {
oneSide(JoinType.CROSS_JOIN, false);
oneSide(JoinType.INNER_JOIN, false);
oneSide(JoinType.LEFT_OUTER_JOIN, true);
@@ -74,46 +75,43 @@ public class PushdownJoinOtherConditionTest {
}
private void oneSide(JoinType joinType, boolean testRight) {
-
- Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1),
Literal.of(18));
- Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1),
Literal.of(50));
+ Expression pushSide1 = new GreaterThan(rStudentSlots.get(1),
Literal.of(18));
+ Expression pushSide2 = new GreaterThan(rStudentSlots.get(1),
Literal.of(50));
List<Expression> condition = ImmutableList.of(pushSide1, pushSide2);
- Plan left = rStudent;
- Plan right = rScore;
+ LogicalOlapScan left = rStudent;
+ LogicalOlapScan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
- Plan join = new LogicalJoin<>(joinType,
ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(),
left, right);
- Plan root = new LogicalProject<>(Lists.newArrayList(), join);
+ LogicalPlan root = new LogicalPlanBuilder(left)
+ .join(right, joinType, ExpressionUtils.EMPTY_CONDITION,
condition)
+ .project(Lists.newArrayList())
+ .build();
- Memo memo = rewrite(root);
- Group rootGroup = memo.getRoot();
+ PlanChecker planChecker =
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new PushdownJoinOtherCondition());
- Plan shouldJoin =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
- Plan shouldFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
- .child(0).getLogicalExpression().getPlan();
- Plan shouldScan =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
- .child(1).getLogicalExpression().getPlan();
if (testRight) {
- shouldFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
- .child(1).getLogicalExpression().getPlan();
- shouldScan =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
- .child(0).getLogicalExpression().getPlan();
- }
-
- Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
- Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
- Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
- LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
+ planChecker.matches(
+ logicalJoin(
+ logicalOlapScan(),
+ logicalFilter()
+ .when(filter ->
ImmutableList.copyOf(filter.getConjuncts()).equals(condition))));
+ } else {
+ planChecker.matches(
+ logicalJoin(
+ logicalFilter().when(
+ filter ->
ImmutableList.copyOf(filter.getConjuncts()).equals(condition)),
+ logicalOlapScan()));
- Assertions.assertEquals(condition,
ImmutableList.copyOf(actualFilter.getConjuncts()));
+ }
}
@Test
- public void bothSideToBothSide() {
+ void bothSideToBothSide() {
bothSideToBothSide(JoinType.CROSS_JOIN);
bothSideToBothSide(JoinType.INNER_JOIN);
bothSideToBothSide(JoinType.LEFT_SEMI_JOIN);
@@ -122,34 +120,26 @@ public class PushdownJoinOtherConditionTest {
private void bothSideToBothSide(JoinType joinType) {
- Expression leftSide = new GreaterThan(rStudent.getOutput().get(1),
Literal.of(18));
- Expression rightSide = new GreaterThan(rScore.getOutput().get(2),
Literal.of(60));
+ Expression leftSide = new GreaterThan(rStudentSlots.get(1),
Literal.of(18));
+ Expression rightSide = new GreaterThan(rScoreSlots.get(2),
Literal.of(60));
List<Expression> condition = ImmutableList.of(leftSide, rightSide);
- Plan join = new LogicalJoin<>(joinType,
ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(),
rStudent,
- rScore);
- Plan root = new LogicalProject<>(Lists.newArrayList(), join);
-
- Memo memo = rewrite(root);
- Group rootGroup = memo.getRoot();
-
- Plan shouldJoin =
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
- Plan leftFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
- .child(0).getLogicalExpression().getPlan();
- Plan rightFilter =
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
- .child(1).getLogicalExpression().getPlan();
-
- Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
- Assertions.assertTrue(leftFilter instanceof LogicalFilter);
- Assertions.assertTrue(rightFilter instanceof LogicalFilter);
- LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
- LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
- Assertions.assertEquals(ImmutableSet.of(leftSide),
actualLeft.getConjuncts());
- Assertions.assertEquals(ImmutableSet.of(rightSide),
actualRight.getConjuncts());
+ LogicalPlan root = new LogicalPlanBuilder(rStudent)
+ .join(rScore, joinType, ExpressionUtils.EMPTY_CONDITION,
condition)
+ .project(Lists.newArrayList())
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new PushdownJoinOtherCondition())
+ .matches(
+ logicalJoin(
+ logicalFilter().when(left ->
left.getConjuncts().equals(ImmutableSet.of(leftSide))),
+ logicalFilter().when(right ->
right.getConjuncts().equals(ImmutableSet.of(rightSide)))
+ ));
}
@Test
- public void bothSideToOneSide() {
+ void bothSideToOneSide() {
bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true);
bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true);
bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false);
@@ -157,45 +147,37 @@ public class PushdownJoinOtherConditionTest {
}
private void bothSideToOneSide(JoinType joinType, boolean testRight) {
-
- Expression pushSide = new GreaterThan(rStudent.getOutput().get(1),
Literal.of(18));
- Expression reserveSide = new GreaterThan(rScore.getOutput().get(2),
Literal.of(60));
+ Expression pushSide = new GreaterThan(rStudentSlots.get(1),
Literal.of(18));
+ Expression reserveSide = new GreaterThan(rScoreSlots.get(2),
Literal.of(60));
List<Expression> condition = ImmutableList.of(pushSide, reserveSide);
- Plan left = rStudent;
- Plan right = rScore;
+ LogicalOlapScan left = rStudent;
+ LogicalOlapScan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
- Plan join = new LogicalJoin<>(joinType,
ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(),
left, right);
- Plan root = new LogicalProject<>(Lists.newArrayList(), join);
+ LogicalPlan root = new LogicalPlanBuilder(left)
+ .join(right, joinType, ExpressionUtils.EMPTY_CONDITION,
condition)
+ .project(Lists.newArrayList())
+ .build();
- Memo memo = rewrite(root);
- Group rootGroup = memo.getRoot();
+ PlanChecker planChecker =
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new PushdownJoinOtherCondition());
- Plan shouldJoin = rootGroup.getLogicalExpression()
- .child(0).getLogicalExpression().getPlan();
- Plan shouldFilter = rootGroup.getLogicalExpression()
-
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
- Plan shouldScan = rootGroup.getLogicalExpression()
-
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
if (testRight) {
- shouldFilter = rootGroup.getLogicalExpression()
-
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
- shouldScan = rootGroup.getLogicalExpression()
-
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
+ planChecker.matches(
+ logicalJoin(
+ logicalOlapScan(),
+ logicalFilter().when(filter ->
filter.getConjuncts().equals(ImmutableSet.of(pushSide)))
+ ));
+ } else {
+ planChecker.matches(
+ logicalJoin(
+ logicalFilter().when(filter ->
filter.getConjuncts().equals(ImmutableSet.of(pushSide))),
+ logicalOlapScan()
+ ));
}
-
- Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
- Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
- Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
- LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
- Assertions.assertEquals(ImmutableSet.of(pushSide),
actualFilter.getConjuncts());
- }
-
- private Memo rewrite(Plan plan) {
- return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new
PushdownJoinOtherCondition());
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java
index 77bf76af4c..d66ede574f 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java
@@ -27,10 +27,10 @@ import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
-public class PushdownProjectThroughLimitTest implements
MemoPatternMatchSupported {
+class PushdownProjectThroughLimitTest implements MemoPatternMatchSupported {
@Test
- public void testPushdownProjectThroughLimit() {
+ void testPushdownProjectThroughLimit() {
LogicalPlan project = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
.limit(1, 1)
.project(ImmutableList.of(0)) // id
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java
index 174f5a90b4..3948406abd 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java
@@ -17,25 +17,32 @@
package org.apache.doris.nereids.rules.rewrite.logical;
-import org.apache.doris.nereids.trees.plans.LimitPhase;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.junit.jupiter.api.Test;
-public class SplitLimitTest {
+/**
+ * Tests for {@link SplitLimit}.
+ */
+class SplitLimitTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan1 =
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
@Test
void testSplitLimit() {
- Plan plan = new LogicalLimit<>(0, 0, LimitPhase.ORIGIN, scan1);
- plan = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .rewrite()
- .getPlan();
- plan.anyMatch(x -> x instanceof LogicalLimit && ((LogicalLimit<?>)
x).isSplit());
+ LogicalPlan limit = new LogicalPlanBuilder(scan1)
+ .limit(0, 0)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), limit)
+ .applyTopDown(new SplitLimit())
+ .matches(
+ globalLogicalLimit(localLogicalLimit())
+ );
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
index f49711aa89..ce713c2f4d 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
@@ -55,6 +55,7 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
@@ -126,6 +127,12 @@ public class PlanChecker {
return applyTopDown(ruleFactory.buildRules());
}
+ public PlanChecker applyTopDown(CustomRewriter customRewriter) {
+ cascadesContext.topDownRewrite(customRewriter);
+ MemoValidator.validate(cascadesContext.getMemo());
+ return this;
+ }
+
public PlanChecker applyTopDown(List<Rule> rule) {
cascadesContext.topDownRewrite(rule);
MemoValidator.validate(cascadesContext.getMemo());
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]