This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new aa1bd29172d branch-4.1: [fix](fe) Fix assert row join pushdown alias handling #63892 (#63934)
aa1bd29172d is described below

commit aa1bd29172d46fcd2bb8e1c9f5845773deb2a5d5
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Mon Jun 1 15:21:01 2026 +0800

    branch-4.1: [fix](fe) Fix assert row join pushdown alias handling #63892 (#63934)
    
    Cherry-picked from #63892
    
    Co-authored-by: morrySnow <[email protected]>
---
 .../rules/rewrite/PushDownJoinOnAssertNumRows.java | 102 ++++++++++-----------
 .../rewrite/PushDownJoinOnAssertNumRowsTest.java   |  66 +++++++++++++
 2 files changed, 112 insertions(+), 56 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
index e52def0723c..d45cc5676fe 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
@@ -80,28 +80,28 @@ public class PushDownJoinOnAssertNumRows extends 
OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return logicalJoin()
-                .when(topJoin -> pattenCheck(topJoin))
-                .then(topJoin -> pushDownAssertNumRowsJoin(topJoin))
+                .when(this::pattenCheck)
+                .then(this::pushDownAssertNumRowsJoin)
                 .toRule(RuleType.PUSH_DOWN_JOIN_ON_ASSERT_NUM_ROWS);
     }
 
-    private boolean pattenCheck(LogicalJoin topJoin) {
+    private boolean pattenCheck(LogicalJoin<?, ?> topJoin) {
         // 1. right is LogicalAssertNumRows or 
LogicalProject->LogicalAssertNumRows
         // 2. left is join or project->join
         // 3. only one join condition.
         if (!topJoin.getJoinType().isInnerOrCrossJoin()) {
             return false;
         }
-        LogicalJoin bottomJoin;
+        LogicalJoin<?, ?> bottomJoin;
         Plan left = topJoin.left();
         Plan right = topJoin.right();
         if (!isAssertOneRowEqOrProjectAssertOneRowEq(right)) {
             return false;
         }
         if (left instanceof LogicalJoin) {
-            bottomJoin = (LogicalJoin) left;
+            bottomJoin = (LogicalJoin<?, ?>) left;
         } else if (left instanceof LogicalProject && left.child(0) instanceof 
LogicalJoin) {
-            bottomJoin = (LogicalJoin) left.child(0);
+            bottomJoin = (LogicalJoin<?, ?>) left.child(0);
         } else {
             return false;
         }
@@ -125,7 +125,7 @@ public class PushDownJoinOnAssertNumRows extends 
OneRewriteRuleFactory {
             plan = plan.child(0);
         }
         if (plan instanceof LogicalAssertNumRows) {
-            AssertNumRowsElement assertNumRowsElement = 
((LogicalAssertNumRows) plan).getAssertNumRowsElement();
+            AssertNumRowsElement assertNumRowsElement = 
((LogicalAssertNumRows<?>) plan).getAssertNumRowsElement();
             if (assertNumRowsElement.getAssertion() == 
AssertNumRowsElement.Assertion.EQ
                     || assertNumRowsElement.getDesiredNumOfRows() == 1L) {
                 return true;
@@ -134,14 +134,14 @@ public class PushDownJoinOnAssertNumRows extends 
OneRewriteRuleFactory {
         return false;
     }
 
-    private boolean joinOnAssertOneRowEq(LogicalJoin join) {
+    private boolean joinOnAssertOneRowEq(LogicalJoin<?, ?> join) {
         return isAssertOneRowEqOrProjectAssertOneRowEq(join.right())
                 || isAssertOneRowEqOrProjectAssertOneRowEq(join.left());
     }
 
-    private Plan pushDownAssertNumRowsJoin(LogicalJoin topJoin) {
+    private Plan pushDownAssertNumRowsJoin(LogicalJoin<?, ?> topJoin) {
         Plan assertBranch = topJoin.right();
-        Expression condition = (Expression) 
topJoin.getOtherJoinConjuncts().get(0);
+        Expression condition = topJoin.getOtherJoinConjuncts().get(0);
         List<Alias> aliasUsedInConditionFromLeftProject = new ArrayList<>();
         LogicalJoin<? extends Plan, ? extends Plan> bottomJoin;
         if (topJoin.left() instanceof LogicalProject) {
@@ -160,59 +160,49 @@ public class PushDownJoinOnAssertNumRows extends 
OneRewriteRuleFactory {
         Plan bottomRight = bottomJoin.right();
 
         List<Slot> conditionSlotsFromTopLeft = 
condition.getInputSlots().stream()
-                .filter(slot -> topJoin.left().getOutputSet().contains(slot))
+                .filter(slot -> bottomJoin.getOutputSet().contains(slot))
                 .collect(Collectors.toList());
+        // Nothing from the bottom join participates in this scalar-subquery 
condition.
+        if (conditionSlotsFromTopLeft.isEmpty()) {
+            return null;
+        }
         if (bottomLeft.getOutputSet().containsAll(conditionSlotsFromTopLeft)) {
-            // push to bottomLeft
-            Plan newBottomLeft;
-            if (aliasUsedInConditionFromLeftProject.isEmpty()) {
-                newBottomLeft = bottomLeft;
-            } else {
-                newBottomLeft = 
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomLeft);
-            }
-            LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new 
LogicalJoin<>(
-                    topJoin.getJoinType(),
-                    topJoin.getHashJoinConjuncts(),
-                    topJoin.getOtherJoinConjuncts(),
-                    newBottomLeft,
-                    assertBranch,
-                    topJoin.getJoinReorderContext());
-            LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = 
(LogicalJoin<? extends Plan, ? extends Plan>)
-                    bottomJoin.withChildren(newBottomJoin, bottomRight);
-            if (topJoin.left() instanceof LogicalProject) {
-                LogicalProject<? extends Plan> upperProject = 
projectAliasOnPlan(
-                        aliasUsedInConditionFromLeftProject, topJoin.left());
-                return upperProject.withChildren(newTopJoin);
-            } else {
-                return newTopJoin;
-            }
+            return assembleNewJoin(bottomLeft, topJoin, bottomJoin, 
bottomRight,
+                    assertBranch, aliasUsedInConditionFromLeftProject, true);
         } else if 
(bottomRight.getOutputSet().containsAll(conditionSlotsFromTopLeft)) {
-            Plan newBottomRight;
-            if (aliasUsedInConditionFromLeftProject.isEmpty()) {
-                newBottomRight = bottomRight;
-            } else {
-                newBottomRight = 
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomRight);
-            }
-            LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new 
LogicalJoin<>(
-                    topJoin.getJoinType(),
-                    topJoin.getHashJoinConjuncts(),
-                    topJoin.getOtherJoinConjuncts(),
-                    newBottomRight,
-                    assertBranch,
-                    topJoin.getJoinReorderContext());
-            LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = 
(LogicalJoin<? extends Plan, ? extends Plan>)
-                    bottomJoin.withChildren(bottomLeft, newBottomJoin);
-            if (topJoin.left() instanceof LogicalProject) {
-                LogicalProject<? extends Plan> upperProject = 
projectAliasOnPlan(
-                        aliasUsedInConditionFromLeftProject, topJoin.left());
-                return upperProject.withChildren(newTopJoin);
-            } else {
-                return newTopJoin;
-            }
+            return assembleNewJoin(bottomRight, topJoin, bottomJoin, 
bottomLeft,
+                    assertBranch, aliasUsedInConditionFromLeftProject, false);
         }
         return null;
     }
 
+    private Plan assembleNewJoin(Plan bottom, LogicalJoin<?, ?> topJoin, 
LogicalJoin<?, ?> bottomJoin, Plan newTopChild,
+            Plan assertBranch, List<Alias> 
aliasUsedInConditionFromLeftProject, boolean pushLeft) {
+        Plan newBottomChild;
+        if (aliasUsedInConditionFromLeftProject.isEmpty()) {
+            newBottomChild = bottom;
+        } else {
+            newBottomChild = 
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottom);
+        }
+        LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new 
LogicalJoin<>(
+                topJoin.getJoinType(),
+                topJoin.getHashJoinConjuncts(),
+                topJoin.getOtherJoinConjuncts(),
+                newBottomChild,
+                assertBranch,
+                topJoin.getJoinReorderContext());
+        LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = 
(LogicalJoin<? extends Plan, ? extends Plan>)
+                (pushLeft ? bottomJoin.withChildren(newBottomJoin, newTopChild)
+                        : bottomJoin.withChildren(newTopChild, newBottomJoin));
+        if (topJoin.left() instanceof LogicalProject) {
+            LogicalProject<? extends Plan> upperProject = projectAliasOnPlan(
+                    aliasUsedInConditionFromLeftProject, topJoin.left());
+            return upperProject.withChildren(newTopJoin);
+        } else {
+            return newTopJoin;
+        }
+    }
+
     @VisibleForTesting
     LogicalProject<? extends Plan> projectAliasOnPlan(List<Alias> projections, 
Plan child) {
         if (child instanceof LogicalProject) {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
index aded31bd18f..d241433a219 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
@@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.LessThan;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
@@ -253,6 +254,71 @@ class PushDownJoinOnAssertNumRowsTest implements 
MemoPatternMatchSupported {
                                                         logicalOlapScan())));
     }
 
+    /**
+     * Test push down when the top join condition uses an alias from the right 
child
+     * of the bottom join. This covers the following shape:
+     *
+     * Before:
+     * topJoin(rhs_score < x)
+     * |-- Project(T1.id, T2.cid + 1 as rhs_score, ...)
+     * |   `-- bottomJoin(T1.id = T2.sid)
+     * |       |-- Scan(T1)
+     * |       `-- Scan(T2)
+     * `-- LogicalAssertNumRows(output=(x, ...))
+     *
+     * After:
+     * Project(...)
+     * `-- bottomJoin(T1.id = T2.sid)
+     *     |-- Scan(T1)
+     *     `-- topJoin(rhs_score < x)
+     *         |-- Project(T2.cid + 1 as rhs_score, ...)
+     *         |   `-- Scan(T2)
+     *         `-- LogicalAssertNumRows(output=(x, ...))
+     */
+    @Test
+    void testPushDownWithProjectAliasFromRightChild() {
+        Plan oneRowRelation = new LogicalPlanBuilder(t3)
+                .limit(1)
+                .build();
+
+        AssertNumRowsElement assertElement = new AssertNumRowsElement(1, "", 
Assertion.EQ);
+        LogicalAssertNumRows<Plan> assertNumRows = new 
LogicalAssertNumRows<>(assertElement, oneRowRelation);
+
+        Expression bottomJoinCondition = new EqualTo(t1Slots.get(0), 
t2Slots.get(0));
+
+        LogicalPlan bottomJoin = new LogicalPlanBuilder(t1)
+                .join(t2, JoinType.INNER_JOIN, 
ImmutableList.of(bottomJoinCondition),
+                                ImmutableList.of())
+                .build();
+
+        Expression addExpr = new Add(t2Slots.get(1), Literal.of(1));
+        Alias rhsScore = new Alias(addExpr, "rhs_score");
+
+        ImmutableList.Builder<NamedExpression> projectListBuilder = 
ImmutableList.builder();
+        projectListBuilder.add(t1Slots.get(0));
+        projectListBuilder.add(t1Slots.get(1));
+        projectListBuilder.add(t2Slots.get(0));
+        projectListBuilder.add(rhsScore);
+
+        LogicalProject<Plan> project = new 
LogicalProject<>(projectListBuilder.build(), bottomJoin);
+
+        Expression topJoinCondition = new LessThan(rhsScore.toSlot(), 
t3Slots.get(0));
+
+        LogicalPlan root = new LogicalPlanBuilder(project)
+                .join(assertNumRows, JoinType.INNER_JOIN, ImmutableList.of(),
+                                ImmutableList.of(topJoinCondition))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new PushDownJoinOnAssertNumRows())
+                .matches(logicalProject(
+                                logicalJoin(
+                                                logicalOlapScan(),
+                                                logicalJoin(
+                                                                
logicalProject(logicalOlapScan()),
+                                                                
logicalAssertNumRows()))));
+    }
+
     /**
      * Test with CROSS JOIN type.
      */


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to