This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new aa1bd29172d branch-4.1: [fix](fe) Fix assert row join pushdown alias
handling #63892 (#63934)
aa1bd29172d is described below
commit aa1bd29172d46fcd2bb8e1c9f5845773deb2a5d5
Author: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Mon Jun 1 15:21:01 2026 +0800
branch-4.1: [fix](fe) Fix assert row join pushdown alias handling #63892
(#63934)
Cherry-picked from #63892
Co-authored-by: morrySnow <[email protected]>
---
.../rules/rewrite/PushDownJoinOnAssertNumRows.java | 102 ++++++++++-----------
.../rewrite/PushDownJoinOnAssertNumRowsTest.java | 66 +++++++++++++
2 files changed, 112 insertions(+), 56 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
index e52def0723c..d45cc5676fe 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
@@ -80,28 +80,28 @@ public class PushDownJoinOnAssertNumRows extends
OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalJoin()
- .when(topJoin -> pattenCheck(topJoin))
- .then(topJoin -> pushDownAssertNumRowsJoin(topJoin))
+ .when(this::pattenCheck)
+ .then(this::pushDownAssertNumRowsJoin)
.toRule(RuleType.PUSH_DOWN_JOIN_ON_ASSERT_NUM_ROWS);
}
- private boolean pattenCheck(LogicalJoin topJoin) {
+ private boolean pattenCheck(LogicalJoin<?, ?> topJoin) {
// 1. right is LogicalAssertNumRows or
LogicalProject->LogicalAssertNumRows
// 2. left is join or project->join
// 3. only one join condition.
if (!topJoin.getJoinType().isInnerOrCrossJoin()) {
return false;
}
- LogicalJoin bottomJoin;
+ LogicalJoin<?, ?> bottomJoin;
Plan left = topJoin.left();
Plan right = topJoin.right();
if (!isAssertOneRowEqOrProjectAssertOneRowEq(right)) {
return false;
}
if (left instanceof LogicalJoin) {
- bottomJoin = (LogicalJoin) left;
+ bottomJoin = (LogicalJoin<?, ?>) left;
} else if (left instanceof LogicalProject && left.child(0) instanceof
LogicalJoin) {
- bottomJoin = (LogicalJoin) left.child(0);
+ bottomJoin = (LogicalJoin<?, ?>) left.child(0);
} else {
return false;
}
@@ -125,7 +125,7 @@ public class PushDownJoinOnAssertNumRows extends
OneRewriteRuleFactory {
plan = plan.child(0);
}
if (plan instanceof LogicalAssertNumRows) {
- AssertNumRowsElement assertNumRowsElement =
((LogicalAssertNumRows) plan).getAssertNumRowsElement();
+ AssertNumRowsElement assertNumRowsElement =
((LogicalAssertNumRows<?>) plan).getAssertNumRowsElement();
if (assertNumRowsElement.getAssertion() ==
AssertNumRowsElement.Assertion.EQ
|| assertNumRowsElement.getDesiredNumOfRows() == 1L) {
return true;
@@ -134,14 +134,14 @@ public class PushDownJoinOnAssertNumRows extends
OneRewriteRuleFactory {
return false;
}
- private boolean joinOnAssertOneRowEq(LogicalJoin join) {
+ private boolean joinOnAssertOneRowEq(LogicalJoin<?, ?> join) {
return isAssertOneRowEqOrProjectAssertOneRowEq(join.right())
|| isAssertOneRowEqOrProjectAssertOneRowEq(join.left());
}
- private Plan pushDownAssertNumRowsJoin(LogicalJoin topJoin) {
+ private Plan pushDownAssertNumRowsJoin(LogicalJoin<?, ?> topJoin) {
Plan assertBranch = topJoin.right();
- Expression condition = (Expression)
topJoin.getOtherJoinConjuncts().get(0);
+ Expression condition = topJoin.getOtherJoinConjuncts().get(0);
List<Alias> aliasUsedInConditionFromLeftProject = new ArrayList<>();
LogicalJoin<? extends Plan, ? extends Plan> bottomJoin;
if (topJoin.left() instanceof LogicalProject) {
@@ -160,59 +160,49 @@ public class PushDownJoinOnAssertNumRows extends
OneRewriteRuleFactory {
Plan bottomRight = bottomJoin.right();
List<Slot> conditionSlotsFromTopLeft =
condition.getInputSlots().stream()
- .filter(slot -> topJoin.left().getOutputSet().contains(slot))
+ .filter(slot -> bottomJoin.getOutputSet().contains(slot))
.collect(Collectors.toList());
+ // Nothing from the bottom join participates in this scalar-subquery
condition.
+ if (conditionSlotsFromTopLeft.isEmpty()) {
+ return null;
+ }
if (bottomLeft.getOutputSet().containsAll(conditionSlotsFromTopLeft)) {
- // push to bottomLeft
- Plan newBottomLeft;
- if (aliasUsedInConditionFromLeftProject.isEmpty()) {
- newBottomLeft = bottomLeft;
- } else {
- newBottomLeft =
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomLeft);
- }
- LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new
LogicalJoin<>(
- topJoin.getJoinType(),
- topJoin.getHashJoinConjuncts(),
- topJoin.getOtherJoinConjuncts(),
- newBottomLeft,
- assertBranch,
- topJoin.getJoinReorderContext());
- LogicalJoin<? extends Plan, ? extends Plan> newTopJoin =
(LogicalJoin<? extends Plan, ? extends Plan>)
- bottomJoin.withChildren(newBottomJoin, bottomRight);
- if (topJoin.left() instanceof LogicalProject) {
- LogicalProject<? extends Plan> upperProject =
projectAliasOnPlan(
- aliasUsedInConditionFromLeftProject, topJoin.left());
- return upperProject.withChildren(newTopJoin);
- } else {
- return newTopJoin;
- }
+ return assembleNewJoin(bottomLeft, topJoin, bottomJoin,
bottomRight,
+ assertBranch, aliasUsedInConditionFromLeftProject, true);
} else if
(bottomRight.getOutputSet().containsAll(conditionSlotsFromTopLeft)) {
- Plan newBottomRight;
- if (aliasUsedInConditionFromLeftProject.isEmpty()) {
- newBottomRight = bottomRight;
- } else {
- newBottomRight =
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomRight);
- }
- LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new
LogicalJoin<>(
- topJoin.getJoinType(),
- topJoin.getHashJoinConjuncts(),
- topJoin.getOtherJoinConjuncts(),
- newBottomRight,
- assertBranch,
- topJoin.getJoinReorderContext());
- LogicalJoin<? extends Plan, ? extends Plan> newTopJoin =
(LogicalJoin<? extends Plan, ? extends Plan>)
- bottomJoin.withChildren(bottomLeft, newBottomJoin);
- if (topJoin.left() instanceof LogicalProject) {
- LogicalProject<? extends Plan> upperProject =
projectAliasOnPlan(
- aliasUsedInConditionFromLeftProject, topJoin.left());
- return upperProject.withChildren(newTopJoin);
- } else {
- return newTopJoin;
- }
+ return assembleNewJoin(bottomRight, topJoin, bottomJoin,
bottomLeft,
+ assertBranch, aliasUsedInConditionFromLeftProject, false);
}
return null;
}
+ private Plan assembleNewJoin(Plan bottom, LogicalJoin<?, ?> topJoin,
LogicalJoin<?, ?> bottomJoin, Plan newTopChild,
+ Plan assertBranch, List<Alias>
aliasUsedInConditionFromLeftProject, boolean pushLeft) {
+ Plan newBottomChild;
+ if (aliasUsedInConditionFromLeftProject.isEmpty()) {
+ newBottomChild = bottom;
+ } else {
+ newBottomChild =
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottom);
+ }
+ LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new
LogicalJoin<>(
+ topJoin.getJoinType(),
+ topJoin.getHashJoinConjuncts(),
+ topJoin.getOtherJoinConjuncts(),
+ newBottomChild,
+ assertBranch,
+ topJoin.getJoinReorderContext());
+ LogicalJoin<? extends Plan, ? extends Plan> newTopJoin =
(LogicalJoin<? extends Plan, ? extends Plan>)
+ (pushLeft ? bottomJoin.withChildren(newBottomJoin, newTopChild)
+ : bottomJoin.withChildren(newTopChild, newBottomJoin));
+ if (topJoin.left() instanceof LogicalProject) {
+ LogicalProject<? extends Plan> upperProject = projectAliasOnPlan(
+ aliasUsedInConditionFromLeftProject, topJoin.left());
+ return upperProject.withChildren(newTopJoin);
+ } else {
+ return newTopJoin;
+ }
+ }
+
@VisibleForTesting
LogicalProject<? extends Plan> projectAliasOnPlan(List<Alias> projections,
Plan child) {
if (child instanceof LogicalProject) {
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
index aded31bd18f..d241433a219 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
@@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
@@ -253,6 +254,71 @@ class PushDownJoinOnAssertNumRowsTest implements
MemoPatternMatchSupported {
logicalOlapScan())));
}
+ /**
+ * Test push down when the top join condition uses an alias from the right
child
+ * of the bottom join. This covers the following shape:
+ *
+ * Before:
+ * topJoin(rhs_score < x)
+ * |-- Project(T1.id, T2.cid + 1 as rhs_score, ...)
+ * | `-- bottomJoin(T1.id = T2.sid)
+ * | |-- Scan(T1)
+ * | `-- Scan(T2)
+ * `-- LogicalAssertNumRows(output=(x, ...))
+ *
+ * After:
+ * Project(...)
+ * `-- bottomJoin(T1.id = T2.sid)
+ * |-- Scan(T1)
+ * `-- topJoin(rhs_score < x)
+ * |-- Project(T2.cid + 1 as rhs_score, ...)
+ * | `-- Scan(T2)
+ * `-- LogicalAssertNumRows(output=(x, ...))
+ */
+ @Test
+ void testPushDownWithProjectAliasFromRightChild() {
+ Plan oneRowRelation = new LogicalPlanBuilder(t3)
+ .limit(1)
+ .build();
+
+ AssertNumRowsElement assertElement = new AssertNumRowsElement(1, "",
Assertion.EQ);
+ LogicalAssertNumRows<Plan> assertNumRows = new
LogicalAssertNumRows<>(assertElement, oneRowRelation);
+
+ Expression bottomJoinCondition = new EqualTo(t1Slots.get(0),
t2Slots.get(0));
+
+ LogicalPlan bottomJoin = new LogicalPlanBuilder(t1)
+ .join(t2, JoinType.INNER_JOIN,
ImmutableList.of(bottomJoinCondition),
+ ImmutableList.of())
+ .build();
+
+ Expression addExpr = new Add(t2Slots.get(1), Literal.of(1));
+ Alias rhsScore = new Alias(addExpr, "rhs_score");
+
+ ImmutableList.Builder<NamedExpression> projectListBuilder =
ImmutableList.builder();
+ projectListBuilder.add(t1Slots.get(0));
+ projectListBuilder.add(t1Slots.get(1));
+ projectListBuilder.add(t2Slots.get(0));
+ projectListBuilder.add(rhsScore);
+
+ LogicalProject<Plan> project = new
LogicalProject<>(projectListBuilder.build(), bottomJoin);
+
+ Expression topJoinCondition = new LessThan(rhsScore.toSlot(),
t3Slots.get(0));
+
+ LogicalPlan root = new LogicalPlanBuilder(project)
+ .join(assertNumRows, JoinType.INNER_JOIN, ImmutableList.of(),
+ ImmutableList.of(topJoinCondition))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new PushDownJoinOnAssertNumRows())
+ .matches(logicalProject(
+ logicalJoin(
+ logicalOlapScan(),
+ logicalJoin(
+
logicalProject(logicalOlapScan()),
+
logicalAssertNumRows()))));
+ }
+
/**
* Test with CROSS JOIN type.
*/
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]