This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 11fbe07221 [refactor](Nereids) Refactor all rewrite logical unit tests by match-pattern (#17691)
11fbe07221 is described below

commit 11fbe072213e1bb007b2e9d1b97217e22045df27
Author: Weijie Guo <[email protected]>
AuthorDate: Sun Mar 12 18:49:12 2023 +0800

    [refactor](Nereids) Refactor all rewrite logical unit tests by match-pattern (#17691)
---
 .../logical/EliminateGroupByConstantTest.java      | 196 ++++++++++-----------
 .../logical/EliminateUnnecessaryProjectTest.java   |  49 ++----
 .../logical/FindHashConditionForJoinTest.java      |  32 ++--
 .../rules/rewrite/logical/MergeFiltersTest.java    |  47 +++--
 .../rules/rewrite/logical/MergeLimitsTest.java     |  46 +++--
 .../logical/PhysicalStorageLayerAggregateTest.java |   4 +-
 .../logical/PruneOlapScanPartitionTest.java        |  74 ++++----
 .../rewrite/logical/PruneOlapScanTabletTest.java   |  20 ++-
 .../logical/PushdownJoinOtherConditionTest.java    | 170 ++++++++----------
 .../logical/PushdownProjectThroughLimitTest.java   |   4 +-
 .../rules/rewrite/logical/SplitLimitTest.java      |  25 ++-
 .../org/apache/doris/nereids/util/PlanChecker.java |   7 +
 12 files changed, 323 insertions(+), 351 deletions(-)

diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java
index 9aa4a027a8..4af6f2234b 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java
@@ -23,30 +23,29 @@ import org.apache.doris.catalog.KeysType;
 import org.apache.doris.catalog.OlapTable;
 import org.apache.doris.catalog.PartitionInfo;
 import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
 import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
 import org.apache.doris.nereids.trees.plans.ObjectId;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.thrift.TStorageType;
 
 import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import java.util.List;
-
-public class EliminateGroupByConstantTest {
+/** Tests for {@link EliminateGroupByConstant}. */
+class EliminateGroupByConstantTest implements MemoPatternMatchSupported {
     private static final OlapTable table = new OlapTable(0L, "student",
             ImmutableList.of(new Column("k1", Type.INT, true, 
AggregateType.NONE, "0", ""),
                     new Column("k2", Type.INT, false, AggregateType.NONE, "0", 
""),
@@ -65,110 +64,107 @@ public class EliminateGroupByConstantTest {
     }
 
     @Test
-    public void testIntegerLiteral() {
-        LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
-                ImmutableList.of(new IntegerLiteral(1), k2),
-                ImmutableList.of(k1, k2),
-                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table)
-        );
-
-        CascadesContext context = 
MemoTestUtils.createCascadesContext(aggregate);
-        context.topDownRewrite(new EliminateGroupByConstant().build());
-        context.bottomUpRewrite(new CheckAfterRewrite().build());
-
-        LogicalAggregate aggregate1 = ((LogicalAggregate) 
context.getMemo().copyOut());
-        Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
-        Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0) 
instanceof Slot);
+    void testIntegerLiteral() {
+        LogicalPlan aggregate = new LogicalPlanBuilder(
+                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table))
+                .agg(ImmutableList.of(new IntegerLiteral(1), k2),
+                     ImmutableList.of(k1, k2))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+                .applyTopDown(new EliminateGroupByConstant())
+                .applyBottomUp(new CheckAfterRewrite())
+                .matches(
+                        aggregate().when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
+                );
     }
 
     @Test
-    public void testOtherLiteral() {
-        LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
-                ImmutableList.of(
-                        new StringLiteral("str"), k2),
-                ImmutableList.of(
-                        new Alias(new StringLiteral("str"), "str"), k1, k2),
-                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table)
-        );
-
-        CascadesContext context = 
MemoTestUtils.createCascadesContext(aggregate);
-        context.topDownRewrite(new EliminateGroupByConstant().build());
-        context.bottomUpRewrite(new CheckAfterRewrite().build());
-
-        LogicalAggregate aggregate1 = ((LogicalAggregate) 
context.getMemo().copyOut());
-        Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
-        Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0) 
instanceof Slot);
+    void testOtherLiteral() {
+        LogicalPlan aggregate = new LogicalPlanBuilder(
+                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table))
+                .agg(ImmutableList.of(
+                             new StringLiteral("str"), k2),
+                     ImmutableList.of(
+                             new Alias(new StringLiteral("str"), "str"), k1, 
k2))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+                .applyTopDown(new EliminateGroupByConstant())
+                .applyBottomUp(new CheckAfterRewrite())
+                .matches(
+                        aggregate().when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
+                );
     }
 
     @Test
-    public void testMixedLiteral() {
-        LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
-                ImmutableList.of(
-                        new StringLiteral("str"), k2,
-                        new IntegerLiteral(1),
-                        new IntegerLiteral(2),
-                        new IntegerLiteral(3),
-                        new Add(k1, k2)),
-                ImmutableList.of(
-                        new Alias(new StringLiteral("str"), "str"),
-                        k2, k1, new Alias(new IntegerLiteral(1), "integer")),
-                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table)
-        );
-
-        CascadesContext context = 
MemoTestUtils.createCascadesContext(aggregate);
-        context.topDownRewrite(new EliminateGroupByConstant().build());
-        context.bottomUpRewrite(new CheckAfterRewrite().build());
-
-        LogicalAggregate aggregate1 = ((LogicalAggregate) 
context.getMemo().copyOut());
-        Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2);
-        List groupByExprs = aggregate1.getGroupByExpressions();
-        Assertions.assertTrue(groupByExprs.get(0) instanceof Slot
-                && groupByExprs.get(1) instanceof Add);
+    void testMixedLiteral() {
+        LogicalPlan aggregate = new LogicalPlanBuilder(
+                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table))
+                .agg(ImmutableList.of(
+                             new StringLiteral("str"), k2,
+                             new IntegerLiteral(1),
+                             new IntegerLiteral(2),
+                             new IntegerLiteral(3),
+                             new Add(k1, k2)),
+                     ImmutableList.of(
+                             new Alias(new StringLiteral("str"), "str"),
+                             k2, k1, new Alias(new IntegerLiteral(1), 
"integer")))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+                .applyTopDown(new EliminateGroupByConstant())
+                .applyBottomUp(new CheckAfterRewrite())
+                .matches(
+                        aggregate()
+                                .when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
+                );
     }
 
     @Test
-    public void testComplexGroupBy() {
-        LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
-                ImmutableList.of(
-                        new IntegerLiteral(1),
-                        new IntegerLiteral(2),
-                        new Add(k1, k2)),
-                ImmutableList.of(
-                        new Alias(new Max(k1), "max"),
-                        new Alias(new Min(k2), "min"),
-                        new Alias(new Add(k1, k2), "add")),
-                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table)
-        );
-
-        CascadesContext context = 
MemoTestUtils.createCascadesContext(aggregate);
-        context.topDownRewrite(new EliminateGroupByConstant().build());
-        context.bottomUpRewrite(new CheckAfterRewrite().build());
-
-        LogicalAggregate aggregate1 = ((LogicalAggregate) 
context.getMemo().copyOut());
-        Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
+    void testComplexGroupBy() {
+        LogicalPlan aggregate = new LogicalPlanBuilder(
+                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table))
+                .agg(ImmutableList.of(
+                             new IntegerLiteral(1),
+                             new IntegerLiteral(2),
+                             new Add(k1, k2)),
+                     ImmutableList.of(
+                             new Alias(new Max(k1), "max"),
+                             new Alias(new Min(k2), "min"),
+                             new Alias(new Add(k1, k2), "add")))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+                .applyTopDown(new EliminateGroupByConstant())
+                .applyBottomUp(new CheckAfterRewrite())
+                .matches(
+                        aggregate()
+                                .when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(new Add(k1, k2))))
+                );
     }
 
     @Test
-    public void testOutOfRange() {
-        LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
-                ImmutableList.of(
-                        new StringLiteral("str"), k2,
-                        new IntegerLiteral(1),
-                        new IntegerLiteral(2),
-                        new IntegerLiteral(3),
-                        new IntegerLiteral(5),
-                        new Add(k1, k2)),
-                ImmutableList.of(
-                        new Alias(new StringLiteral("str"), "str"),
-                        k2, k1, new Alias(new IntegerLiteral(1), "integer")),
-                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table)
-        );
-
-        CascadesContext context = 
MemoTestUtils.createCascadesContext(aggregate);
-        context.topDownRewrite(new EliminateGroupByConstant().build());
-        context.bottomUpRewrite(new CheckAfterRewrite().build());
-
-        LogicalAggregate aggregate1 = ((LogicalAggregate) 
context.getMemo().copyOut());
-        Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2);
+    void testOutOfRange() {
+        LogicalPlan aggregate = new LogicalPlanBuilder(
+                new LogicalOlapScan(ObjectId.createGenerator().getNextId(), 
table))
+                .agg(ImmutableList.of(
+                             new StringLiteral("str"), k2,
+                             new IntegerLiteral(1),
+                             new IntegerLiteral(2),
+                             new IntegerLiteral(3),
+                             new IntegerLiteral(5),
+                             new Add(k1, k2)),
+                     ImmutableList.of(
+                                     new Alias(new StringLiteral("str"), 
"str"),
+                                     k2, k1, new Alias(new IntegerLiteral(1), 
"integer")))
+                .build();
+        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+                .applyTopDown(new EliminateGroupByConstant())
+                .applyBottomUp(new CheckAfterRewrite())
+                .matches(
+                        aggregate()
+                                .when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
+                );
     }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
index ecd6de1cd5..4435eb16f8 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
@@ -17,28 +17,25 @@
 
 package org.apache.doris.nereids.rules.rewrite.logical;
 
-import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
-import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 import org.apache.doris.utframe.TestWithFeService;
 
 import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 /**
  * test ELIMINATE_UNNECESSARY_PROJECT rule.
  */
-public class EliminateUnnecessaryProjectTest extends TestWithFeService {
+class EliminateUnnecessaryProjectTest extends TestWithFeService implements 
MemoPatternMatchSupported {
 
     @Override
     protected void runBeforeAll() throws Exception {
@@ -55,57 +52,49 @@ public class EliminateUnnecessaryProjectTest extends TestWithFeService {
     }
 
     @Test
-    public void testEliminateNonTopUnnecessaryProject() {
+    void testEliminateNonTopUnnecessaryProject() {
         LogicalPlan unnecessaryProject = new 
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
                 .project(ImmutableList.of(1, 0))
                 .filter(BooleanLiteral.FALSE)
                 .build();
 
-        CascadesContext cascadesContext = 
MemoTestUtils.createCascadesContext(unnecessaryProject);
-        cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
-
-        Plan actual = cascadesContext.getMemo().copyOut();
-        Assertions.assertTrue(actual.child(0) instanceof LogicalProject);
+        PlanChecker.from(MemoTestUtils.createConnectContext(), 
unnecessaryProject)
+                .applyTopDown(new EliminateUnnecessaryProject())
+                .matchesFromRoot(logicalFilter(logicalProject()));
     }
 
     @Test
-    public void testEliminateTopUnnecessaryProject() {
+    void testEliminateTopUnnecessaryProject() {
         LogicalPlan unnecessaryProject = new 
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
                 .project(ImmutableList.of(0, 1))
                 .build();
 
-        CascadesContext cascadesContext = 
MemoTestUtils.createCascadesContext(unnecessaryProject);
-        cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
-
-        Plan actual = cascadesContext.getMemo().copyOut();
-        Assertions.assertTrue(actual instanceof LogicalOlapScan);
+        PlanChecker.from(MemoTestUtils.createConnectContext(), 
unnecessaryProject)
+                .applyTopDown(new EliminateUnnecessaryProject())
+                .matchesFromRoot(logicalOlapScan());
     }
 
     @Test
-    public void testNotEliminateTopProjectWhenOutputNotEquals() {
+    void testNotEliminateTopProjectWhenOutputNotEquals() {
         LogicalPlan necessaryProject = new 
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
                 .project(ImmutableList.of(1, 0))
                 .build();
 
-        CascadesContext cascadesContext = 
MemoTestUtils.createCascadesContext(necessaryProject);
-        cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
-
-        Plan actual = cascadesContext.getMemo().copyOut();
-        Assertions.assertTrue(actual instanceof LogicalProject);
+        PlanChecker.from(MemoTestUtils.createConnectContext(), 
necessaryProject)
+                .applyTopDown(new EliminateUnnecessaryProject())
+                .matchesFromRoot(logicalProject());
     }
 
     @Test
-    public void testEliminateProjectWhenEmptyRelationChild() {
+    void testEliminateProjectWhenEmptyRelationChild() {
         LogicalPlan unnecessaryProject = new LogicalPlanBuilder(new 
LogicalEmptyRelation(ImmutableList.of(
                 new SlotReference("k1", IntegerType.INSTANCE),
                 new SlotReference("k2", IntegerType.INSTANCE))))
                 .project(ImmutableList.of(1, 0))
                 .build();
-        CascadesContext cascadesContext = 
MemoTestUtils.createCascadesContext(unnecessaryProject);
-        cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
-
-        Plan actual = cascadesContext.getMemo().copyOut();
-        Assertions.assertTrue(actual instanceof LogicalEmptyRelation);
+        PlanChecker.from(MemoTestUtils.createConnectContext(), 
unnecessaryProject)
+                .applyTopDown(new EliminateUnnecessaryProject())
+                .matchesFromRoot(logicalEmptyRelation());
     }
 
     // TODO: uncomment this after the Elimination project rule is correctly 
implemented
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
index 020e1d79d6..8bf0bada3f 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java
@@ -17,8 +17,6 @@
 
 package org.apache.doris.nereids.rules.rewrite.logical;
 
-import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -31,12 +29,12 @@ import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
@@ -54,9 +52,9 @@ import java.util.Optional;
  *       -hashJoinConjuncts={A.x=B.x, A.y+1=B.y}
  *       -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x"
  */
-class FindHashConditionForJoinTest {
+class FindHashConditionForJoinTest implements MemoPatternMatchSupported {
     @Test
-    public void testFindHashCondition() {
+    void testFindHashCondition() {
         Plan student = new 
LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
                 ImmutableList.of(""));
         Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
PlanConstructor.score,
@@ -77,20 +75,12 @@ class FindHashConditionForJoinTest {
         List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less);
         LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new 
ArrayList<>(),
                 expr, JoinHint.NONE, Optional.empty(), student, score);
-        CascadesContext context = MemoTestUtils.createCascadesContext(join);
-        List<Rule> rules = Lists.newArrayList(new 
FindHashConditionForJoin().build());
 
-        context.topDownRewrite(rules);
-        Plan plan = context.getMemo().copyOut();
-        LogicalJoin after = (LogicalJoin) plan;
-        Assertions.assertEquals(after.getHashJoinConjuncts().size(), 2);
-        Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq1));
-        Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq3));
-        List<Expression> others = after.getOtherJoinConjuncts();
-        Assertions.assertEquals(others.size(), 3);
-        Assertions.assertTrue(others.contains(less));
-        Assertions.assertTrue(others.contains(eq2));
-        Assertions.assertTrue(others.contains(less));
+        PlanChecker.from(new ConnectContext(), join)
+                        .applyTopDown(new FindHashConditionForJoin())
+                        .matches(
+                            logicalJoin()
+                                    .when(j -> 
j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3)))
+                                    .when(j -> 
j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less))));
     }
-
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java
index 31bf853bdf..e706151e0e 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java
@@ -17,47 +17,42 @@
 
 package org.apache.doris.nereids.rules.rewrite.logical;
 
-import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.analyzer.UnboundRelation;
-import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
-import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import java.util.List;
-import java.util.Set;
-
 /**
- * MergeConsecutiveFilter ut
+ * Tests for {@link MergeFilters}.
  */
-public class MergeFiltersTest {
+class MergeFiltersTest implements MemoPatternMatchSupported {
     @Test
-    public void testMergeConsecutiveFilters() {
-        UnboundRelation relation = new 
UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", 
"table"));
+    void testMergeFilters() {
         Expression expression1 = new IntegerLiteral(1);
-        LogicalFilter filter1 = new 
LogicalFilter<>(ImmutableSet.of(expression1), relation);
         Expression expression2 = new IntegerLiteral(2);
-        LogicalFilter filter2 = new 
LogicalFilter<>(ImmutableSet.of(expression2), filter1);
         Expression expression3 = new IntegerLiteral(3);
-        LogicalFilter filter3 = new 
LogicalFilter<>(ImmutableSet.of(expression3), filter2);
 
-        CascadesContext cascadesContext = 
MemoTestUtils.createCascadesContext(filter3);
-        List<Rule> rules = Lists.newArrayList(new MergeFilters().build());
-        cascadesContext.bottomUpRewrite(rules);
-        //check transformed plan
-        Plan resultPlan = cascadesContext.getMemo().copyOut();
-        System.out.println(resultPlan.treeString());
-        Assertions.assertTrue(resultPlan instanceof LogicalFilter);
-        Set<Expression> allPredicates = ImmutableSet.of(expression1, 
expression2, expression3);
-        Assertions.assertEquals(ImmutableSet.copyOf(((LogicalFilter<?>) 
resultPlan).getConjuncts()), allPredicates);
-        Assertions.assertTrue(resultPlan.child(0) instanceof UnboundRelation);
+        LogicalPlan logicalFilter = new LogicalPlanBuilder(
+                new UnboundRelation(RelationUtil.newRelationId(), 
Lists.newArrayList("db", "table")))
+                .filter(ImmutableSet.of(expression1))
+                .filter(ImmutableSet.of(expression2))
+                .filter(ImmutableSet.of(expression3))
+                .build();
+
+        PlanChecker.from(new ConnectContext(), 
logicalFilter).applyBottomUp(new MergeFilters())
+                .matches(
+                        logicalFilter(
+                                unboundRelation()
+                        ).when(filter -> filter.getConjuncts()
+                                .equals(ImmutableSet.of(expression1, 
expression2, expression3))));
     }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java
index 869dec982f..ae608f1d26 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java
@@ -17,37 +17,33 @@
 
 package org.apache.doris.nereids.rules.rewrite.logical;
 
-import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.analyzer.UnboundRelation;
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.trees.plans.LimitPhase;
-import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
-import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import java.util.List;
-
-public class MergeLimitsTest {
+/**
+ * Tests for {@link MergeLimits}.
+ */
+class MergeLimitsTest implements MemoPatternMatchSupported {
     @Test
-    public void testMergeConsecutiveLimits() {
-        LogicalLimit limit3 = new LogicalLimit<>(3, 5, LimitPhase.ORIGIN, new 
UnboundRelation(
-                RelationUtil.newRelationId(), Lists.newArrayList("db", "t")));
-        LogicalLimit limit2 = new LogicalLimit<>(2, 0, LimitPhase.ORIGIN, 
limit3);
-        LogicalLimit limit1 = new LogicalLimit<>(10, 0, LimitPhase.ORIGIN, 
limit2);
-
-        CascadesContext context = MemoTestUtils.createCascadesContext(limit1);
-        List<Rule> rules = Lists.newArrayList(new MergeLimits().build());
-        context.topDownRewrite(rules);
-        LogicalLimit limit = (LogicalLimit) context.getMemo().copyOut();
-
-        Assertions.assertEquals(2, limit.getLimit());
-        Assertions.assertEquals(5, limit.getOffset());
-        Assertions.assertEquals(1, limit.children().size());
-        Assertions.assertTrue(limit.child(0) instanceof UnboundRelation);
-
+    void testMergeLimits() {
+        LogicalPlan logicalLimit = new LogicalPlanBuilder(
+                new UnboundRelation(RelationUtil.newRelationId(), 
Lists.newArrayList("db", "t")))
+                .limit(3, 5)
+                .limit(2, 0)
+                .limit(10, 0).build();
+
+        PlanChecker.from(new ConnectContext(), logicalLimit).applyTopDown(new 
MergeLimits())
+                .matches(
+                        logicalLimit(
+                                unboundRelation()
+                        ).when(limit -> limit.getLimit() == 2).when(limit -> 
limit.getOffset() == 5));
     }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java
index 1556438652..69e07eebc7 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java
@@ -18,7 +18,6 @@
 package org.apache.doris.nereids.rules.rewrite.logical;
 
 import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.pattern.GeneratedMemoPatterns;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RulePromise;
 import org.apache.doris.nereids.rules.RuleType;
@@ -32,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import 
org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
@@ -42,7 +42,7 @@ import org.junit.jupiter.api.Test;
 import java.util.Collections;
 import java.util.Optional;
 
-public class PhysicalStorageLayerAggregateTest implements 
GeneratedMemoPatterns {
+public class PhysicalStorageLayerAggregateTest implements 
MemoPatternMatchSupported {
 
     @Test
     public void testWithoutProject() {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
index a22d2b29f5..73156a6cfd 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java
@@ -26,8 +26,6 @@ import org.apache.doris.catalog.RangePartitionInfo;
 import org.apache.doris.catalog.RangePartitionItem;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.jmockit.Deencapsulation;
-import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
 import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
@@ -36,20 +34,19 @@ import 
org.apache.doris.nereids.trees.expressions.LessThanEqual;
 import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
-import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
 import com.google.common.collect.BoundType;
 import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Lists;
 import com.google.common.collect.Range;
 import mockit.Expectations;
 import mockit.Mocked;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
@@ -58,10 +55,10 @@ import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
-class PruneOlapScanPartitionTest {
+class PruneOlapScanPartitionTest implements MemoPatternMatchSupported {
 
     @Test
-    public void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable 
olapTable) throws Exception {
+    void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable 
olapTable) throws Exception {
         List<Column> columnNameList = new ArrayList<>();
         columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
         columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
@@ -92,24 +89,29 @@ class PruneOlapScanPartitionTest {
         Expression expression = new LessThan(slotRef, new IntegerLiteral(4));
         LogicalFilter<LogicalOlapScan> filter = new 
LogicalFilter<>(ImmutableSet.of(expression), scan);
 
-        CascadesContext cascadesContext = 
MemoTestUtils.createCascadesContext(filter);
-        List<Rule> rules = Lists.newArrayList(new 
PruneOlapScanPartition().build());
-        cascadesContext.topDownRewrite(rules);
-        Plan resultPlan = cascadesContext.getMemo().copyOut();
-        LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan) 
resultPlan.child(0);
-        Assertions.assertEquals(0L, 
rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
+        PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+                .applyTopDown(new PruneOlapScanPartition())
+                .matches(
+                        logicalFilter(
+                                logicalOlapScan().when(
+                                        olapScan -> 
olapScan.getSelectedPartitionIds().iterator().next() == 0L)
+                        )
+                );
 
         Expression lessThan0 = new LessThan(slotRef, new IntegerLiteral(0));
         Expression greaterThan6 = new GreaterThan(slotRef, new 
IntegerLiteral(6));
         Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6);
         filter = new LogicalFilter<>(ImmutableSet.of(lessThan0OrGreaterThan6), 
scan);
         scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
olapTable);
-        cascadesContext = MemoTestUtils.createCascadesContext(filter);
-        rules = Lists.newArrayList(new PruneOlapScanPartition().build());
-        cascadesContext.topDownRewrite(rules);
-        resultPlan = cascadesContext.getMemo().copyOut();
-        rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
-        Assertions.assertEquals(1L, 
rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+                .applyTopDown(new PruneOlapScanPartition())
+                .matches(
+                        logicalFilter(
+                                logicalOlapScan().when(
+                                        olapScan -> 
olapScan.getSelectedPartitionIds().iterator().next() == 1L)
+                        )
+                );
 
         Expression greaterThanEqual0 =
                 new GreaterThanEqual(
@@ -118,17 +120,20 @@ class PruneOlapScanPartitionTest {
                 new LessThanEqual(slotRef, new IntegerLiteral(5));
         scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
olapTable);
         filter = new LogicalFilter<>(ImmutableSet.of(greaterThanEqual0, 
lessThanEqual5), scan);
-        cascadesContext = MemoTestUtils.createCascadesContext(filter);
-        rules = Lists.newArrayList(new PruneOlapScanPartition().build());
-        cascadesContext.topDownRewrite(rules);
-        resultPlan = cascadesContext.getMemo().copyOut();
-        rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
-        Assertions.assertEquals(0L, 
rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
-        Assertions.assertEquals(2, 
rewrittenOlapScan.getSelectedPartitionIds().toArray().length);
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+                .applyTopDown(new PruneOlapScanPartition())
+                .matches(
+                        logicalFilter(
+                                logicalOlapScan().when(
+                                                olapScan -> 
olapScan.getSelectedPartitionIds().iterator().next() == 0L)
+                                        .when(olapScan -> 
olapScan.getSelectedPartitionIds().size() == 2)
+                        )
+                );
     }
 
     @Test
-    public void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked 
OlapTable olapTable) throws Exception {
+    void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable 
olapTable) throws Exception {
         List<Column> columnNameList = new ArrayList<>();
         columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
         columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
@@ -155,11 +160,14 @@ class PruneOlapScanPartitionTest {
         Expression left = new LessThan(new SlotReference("col1", 
IntegerType.INSTANCE), new IntegerLiteral(4));
         Expression right = new GreaterThan(new SlotReference("col2", 
IntegerType.INSTANCE), new IntegerLiteral(11));
         LogicalFilter<LogicalOlapScan> filter = new 
LogicalFilter<>(ImmutableSet.of(left, right), scan);
-        CascadesContext cascadesContext = 
MemoTestUtils.createCascadesContext(filter);
-        List<Rule> rules = Lists.newArrayList(new 
PruneOlapScanPartition().build());
-        cascadesContext.topDownRewrite(rules);
-        Plan resultPlan = cascadesContext.getMemo().copyOut();
-        LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan) 
resultPlan.child(0);
-        Assertions.assertEquals(0L, 
rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
+        PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+                .applyTopDown(new PruneOlapScanPartition())
+                .matches(
+                        logicalFilter(
+                                logicalOlapScan()
+                                        .when(
+                                                olapScan -> 
olapScan.getSelectedPartitionIds().iterator().next() == 0L)
+                        )
+                );
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java
index 7ee589a422..5d34aec5db 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java
@@ -29,7 +29,6 @@ import org.apache.doris.catalog.OlapTable;
 import org.apache.doris.catalog.Partition;
 import org.apache.doris.catalog.PrimitiveType;
 import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
 import org.apache.doris.nereids.trees.expressions.InPredicate;
@@ -41,7 +40,9 @@ import org.apache.doris.nereids.trees.plans.ObjectId;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.planner.PartitionColumnFilter;
 
 import com.google.common.collect.ImmutableList;
@@ -54,10 +55,10 @@ import org.junit.jupiter.api.Test;
 
 import java.util.List;
 
-public class PruneOlapScanTabletTest {
+class PruneOlapScanTabletTest implements MemoPatternMatchSupported {
 
     @Test
-    public void testPruneOlapScanTablet(@Mocked OlapTable olapTable,
+    void testPruneOlapScanTablet(@Mocked OlapTable olapTable,
             @Mocked Partition partition, @Mocked MaterializedIndex index,
             @Mocked HashDistributionInfo distributionInfo) {
         List<Long> tabletIds = Lists.newArrayListWithExpectedSize(300);
@@ -153,11 +154,12 @@ public class PruneOlapScanTabletTest {
 
         Assertions.assertEquals(0, 
filter.child().getSelectedTabletIds().size());
 
-        CascadesContext context = MemoTestUtils.createCascadesContext(filter);
-        context.topDownRewrite(ImmutableList.of(new 
PruneOlapScanTablet().build()));
-
-        LogicalFilter<LogicalOlapScan> filter1 = 
((LogicalFilter<LogicalOlapScan>) context.getMemo().copyOut());
-        LogicalOlapScan olapScan = filter1.child();
-        Assertions.assertEquals(19, olapScan.getSelectedTabletIds().size());
+        PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
+                .applyTopDown(new PruneOlapScanTablet())
+                .matches(
+                        logicalFilter(
+                                logicalOlapScan().when(scan -> 
scan.getSelectedTabletIds().size() == 19)
+                        )
+                );
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java
index 48dbbb7621..b04318b072 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java
@@ -17,52 +17,53 @@
 
 package org.apache.doris.nereids.rules.rewrite.logical;
 
-import org.apache.doris.nereids.memo.Group;
-import org.apache.doris.nereids.memo.Memo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
-import org.apache.doris.nereids.trees.plans.JoinHint;
 import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.nereids.util.PlanRewriter;
-import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
 
 import java.util.List;
-import java.util.Optional;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class PushdownJoinOtherConditionTest {
+class PushdownJoinOtherConditionTest implements MemoPatternMatchSupported {
 
-    private Plan rStudent;
-    private Plan rScore;
+    private LogicalOlapScan rStudent;
+    private LogicalOlapScan rScore;
+
+    private List<Slot> rStudentSlots;
+
+    private List<Slot> rScoreSlots;
 
     /**
      * ut before.
      */
     @BeforeAll
-    public final void beforeAll() {
+    final void beforeAll() {
         rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
PlanConstructor.student,
                 ImmutableList.of(""));
         rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
PlanConstructor.score, ImmutableList.of(""));
+        rStudentSlots = rStudent.getOutput();
+        rScoreSlots = rScore.getOutput();
     }
 
     @Test
-    public void oneSide() {
+    void oneSide() {
         oneSide(JoinType.CROSS_JOIN, false);
         oneSide(JoinType.INNER_JOIN, false);
         oneSide(JoinType.LEFT_OUTER_JOIN, true);
@@ -74,46 +75,43 @@ public class PushdownJoinOtherConditionTest {
     }
 
     private void oneSide(JoinType joinType, boolean testRight) {
-
-        Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), 
Literal.of(18));
-        Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), 
Literal.of(50));
+        Expression pushSide1 = new GreaterThan(rStudentSlots.get(1), 
Literal.of(18));
+        Expression pushSide2 = new GreaterThan(rStudentSlots.get(1), 
Literal.of(50));
         List<Expression> condition = ImmutableList.of(pushSide1, pushSide2);
 
-        Plan left = rStudent;
-        Plan right = rScore;
+        LogicalOlapScan left = rStudent;
+        LogicalOlapScan right = rScore;
         if (testRight) {
             left = rScore;
             right = rStudent;
         }
 
-        Plan join = new LogicalJoin<>(joinType, 
ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), 
left, right);
-        Plan root = new LogicalProject<>(Lists.newArrayList(), join);
+        LogicalPlan root = new LogicalPlanBuilder(left)
+                .join(right, joinType, ExpressionUtils.EMPTY_CONDITION, 
condition)
+                .project(Lists.newArrayList())
+                .build();
 
-        Memo memo = rewrite(root);
-        Group rootGroup = memo.getRoot();
+        PlanChecker planChecker = 
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new PushdownJoinOtherCondition());
 
-        Plan shouldJoin = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
-        Plan shouldFilter = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
-                .child(0).getLogicalExpression().getPlan();
-        Plan shouldScan = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
-                .child(1).getLogicalExpression().getPlan();
         if (testRight) {
-            shouldFilter = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
-                    .child(1).getLogicalExpression().getPlan();
-            shouldScan = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
-                    .child(0).getLogicalExpression().getPlan();
-        }
-
-        Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
-        Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
-        Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
-        LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
+            planChecker.matches(
+                    logicalJoin(
+                            logicalOlapScan(),
+                            logicalFilter()
+                                    .when(filter -> 
ImmutableList.copyOf(filter.getConjuncts()).equals(condition))));
+        } else {
+            planChecker.matches(
+                    logicalJoin(
+                            logicalFilter().when(
+                                    filter -> 
ImmutableList.copyOf(filter.getConjuncts()).equals(condition)),
+                            logicalOlapScan()));
 
-        Assertions.assertEquals(condition, 
ImmutableList.copyOf(actualFilter.getConjuncts()));
+        }
     }
 
     @Test
-    public void bothSideToBothSide() {
+    void bothSideToBothSide() {
         bothSideToBothSide(JoinType.CROSS_JOIN);
         bothSideToBothSide(JoinType.INNER_JOIN);
         bothSideToBothSide(JoinType.LEFT_SEMI_JOIN);
@@ -122,34 +120,26 @@ public class PushdownJoinOtherConditionTest {
 
     private void bothSideToBothSide(JoinType joinType) {
 
-        Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), 
Literal.of(18));
-        Expression rightSide = new GreaterThan(rScore.getOutput().get(2), 
Literal.of(60));
+        Expression leftSide = new GreaterThan(rStudentSlots.get(1), 
Literal.of(18));
+        Expression rightSide = new GreaterThan(rScoreSlots.get(2), 
Literal.of(60));
         List<Expression> condition = ImmutableList.of(leftSide, rightSide);
 
-        Plan join = new LogicalJoin<>(joinType, 
ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), 
rStudent,
-                rScore);
-        Plan root = new LogicalProject<>(Lists.newArrayList(), join);
-
-        Memo memo = rewrite(root);
-        Group rootGroup = memo.getRoot();
-
-        Plan shouldJoin = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
-        Plan leftFilter = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
-                .child(0).getLogicalExpression().getPlan();
-        Plan rightFilter = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression()
-                .child(1).getLogicalExpression().getPlan();
-
-        Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
-        Assertions.assertTrue(leftFilter instanceof LogicalFilter);
-        Assertions.assertTrue(rightFilter instanceof LogicalFilter);
-        LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
-        LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
-        Assertions.assertEquals(ImmutableSet.of(leftSide), 
actualLeft.getConjuncts());
-        Assertions.assertEquals(ImmutableSet.of(rightSide), 
actualRight.getConjuncts());
+        LogicalPlan root = new LogicalPlanBuilder(rStudent)
+                .join(rScore, joinType, ExpressionUtils.EMPTY_CONDITION, 
condition)
+                .project(Lists.newArrayList())
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new PushdownJoinOtherCondition())
+                .matches(
+                        logicalJoin(
+                                logicalFilter().when(left -> 
left.getConjuncts().equals(ImmutableSet.of(leftSide))),
+                                logicalFilter().when(right -> 
right.getConjuncts().equals(ImmutableSet.of(rightSide)))
+                        ));
     }
 
     @Test
-    public void bothSideToOneSide() {
+    void bothSideToOneSide() {
         bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true);
         bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true);
         bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false);
@@ -157,45 +147,37 @@ public class PushdownJoinOtherConditionTest {
     }
 
     private void bothSideToOneSide(JoinType joinType, boolean testRight) {
-
-        Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), 
Literal.of(18));
-        Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), 
Literal.of(60));
+        Expression pushSide = new GreaterThan(rStudentSlots.get(1), 
Literal.of(18));
+        Expression reserveSide = new GreaterThan(rScoreSlots.get(2), 
Literal.of(60));
         List<Expression> condition = ImmutableList.of(pushSide, reserveSide);
 
-        Plan left = rStudent;
-        Plan right = rScore;
+        LogicalOlapScan left = rStudent;
+        LogicalOlapScan right = rScore;
         if (testRight) {
             left = rScore;
             right = rStudent;
         }
 
-        Plan join = new LogicalJoin<>(joinType, 
ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), 
left, right);
-        Plan root = new LogicalProject<>(Lists.newArrayList(), join);
+        LogicalPlan root = new LogicalPlanBuilder(left)
+                .join(right, joinType, ExpressionUtils.EMPTY_CONDITION, 
condition)
+                .project(Lists.newArrayList())
+                .build();
 
-        Memo memo = rewrite(root);
-        Group rootGroup = memo.getRoot();
+        PlanChecker planChecker = 
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new PushdownJoinOtherCondition());
 
-        Plan shouldJoin = rootGroup.getLogicalExpression()
-                .child(0).getLogicalExpression().getPlan();
-        Plan shouldFilter = rootGroup.getLogicalExpression()
-                
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
-        Plan shouldScan = rootGroup.getLogicalExpression()
-                
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
         if (testRight) {
-            shouldFilter = rootGroup.getLogicalExpression()
-                    
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
-            shouldScan = rootGroup.getLogicalExpression()
-                    
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
+            planChecker.matches(
+                    logicalJoin(
+                            logicalOlapScan(),
+                            logicalFilter().when(filter -> 
filter.getConjuncts().equals(ImmutableSet.of(pushSide)))
+                    ));
+        } else {
+            planChecker.matches(
+                    logicalJoin(
+                            logicalFilter().when(filter -> 
filter.getConjuncts().equals(ImmutableSet.of(pushSide))),
+                            logicalOlapScan()
+                    ));
         }
-
-        Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
-        Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
-        Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
-        LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
-        Assertions.assertEquals(ImmutableSet.of(pushSide), 
actualFilter.getConjuncts());
-    }
-
-    private Memo rewrite(Plan plan) {
-        return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new 
PushdownJoinOtherCondition());
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java
index 77bf76af4c..d66ede574f 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java
@@ -27,10 +27,10 @@ import org.apache.doris.nereids.util.PlanConstructor;
 import com.google.common.collect.ImmutableList;
 import org.junit.jupiter.api.Test;
 
-public class PushdownProjectThroughLimitTest implements 
MemoPatternMatchSupported {
+class PushdownProjectThroughLimitTest implements MemoPatternMatchSupported {
 
     @Test
-    public void testPushdownProjectThroughLimit() {
+    void testPushdownProjectThroughLimit() {
         LogicalPlan project = new 
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
                 .limit(1, 1)
                 .project(ImmutableList.of(0)) // id
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java
index 174f5a90b4..3948406abd 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java
@@ -17,25 +17,32 @@
 
 package org.apache.doris.nereids.rules.rewrite.logical;
 
-import org.apache.doris.nereids.trees.plans.LimitPhase;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
 import org.junit.jupiter.api.Test;
 
-public class SplitLimitTest {
+/**
+ * Tests for {@link SplitLimit}.
+ */
+class SplitLimitTest implements MemoPatternMatchSupported {
     private final LogicalOlapScan scan1 = 
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
 
     @Test
     void testSplitLimit() {
-        Plan plan = new LogicalLimit<>(0, 0, LimitPhase.ORIGIN, scan1);
-        plan = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .rewrite()
-                .getPlan();
-        plan.anyMatch(x -> x instanceof LogicalLimit && ((LogicalLimit<?>) 
x).isSplit());
+        LogicalPlan limit = new LogicalPlanBuilder(scan1)
+                .limit(0, 0)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), limit)
+                .applyTopDown(new SplitLimit())
+                .matches(
+                        globalLogicalLimit(localLogicalLimit())
+                );
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
index f49711aa89..ce713c2f4d 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java
@@ -55,6 +55,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
 import org.apache.doris.planner.PlanFragment;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.OriginStatement;
@@ -126,6 +127,12 @@ public class PlanChecker {
         return applyTopDown(ruleFactory.buildRules());
     }
 
+    public PlanChecker applyTopDown(CustomRewriter customRewriter) {
+        cascadesContext.topDownRewrite(customRewriter);
+        MemoValidator.validate(cascadesContext.getMemo());
+        return this;
+    }
+
     public PlanChecker applyTopDown(List<Rule> rule) {
         cascadesContext.topDownRewrite(rule);
         MemoValidator.validate(cascadesContext.getMemo());


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to