This is an automated email from the ASF dual-hosted git repository.
englefly pushed a commit to branch tpc_preview
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/tpc_preview by this push:
new 8d077cc3da7 transform inner join to semi join (#58220)
8d077cc3da7 is described below
commit 8d077cc3da74ea619cda60a48a0e55bb68266e0d
Author: minghong <[email protected]>
AuthorDate: Fri Nov 21 14:12:28 2025 +0800
transform inner join to semi join (#58220)
### What problem does this PR solve?
Issue Number: close #xxx
Related PR: #xxx
Problem Summary:
### Release note
None
### Check List (For Author)
- Test <!-- At least one of them must be included. -->
- [ ] Regression test
- [ ] Unit Test
- [ ] Manual test (add detailed scripts or steps below)
- [ ] No need to test or manual test. Explain why:
- [ ] This is a refactor/code format and no logic has been changed.
- [ ] Previous test can cover this change.
- [ ] No code files have been changed.
- [ ] Other reason <!-- Add your reason? -->
- Behavior changed:
- [ ] No.
- [ ] Yes. <!-- Explain the behavior change -->
- Does this need documentation?
- [ ] No.
- [ ] Yes. <!-- Add document PR link here. eg:
https://github.com/apache/doris-website/pull/1214 -->
### Check List (For Reviewer who merge this PR)
- [ ] Confirm the release note
- [ ] Confirm test cases
- [ ] Confirm document
- [ ] Add branch pick label <!-- Add branch pick label that this PR
should merge into -->
---
.../doris/nereids/jobs/executor/Rewriter.java | 3 +
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../rules/rewrite/AggInnerJoinToSemiJoin.java | 105 ++++++++++++
.../rules/rewrite/AggInnerJoinToSemiJoinTest.java | 185 +++++++++++++++++++++
4 files changed, 294 insertions(+)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index f283aaae139..7ab92737551 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -38,6 +38,7 @@ import
org.apache.doris.nereids.rules.rewrite.AddProjectForJoin;
import org.apache.doris.nereids.rules.rewrite.AddProjectForUniqueFunction;
import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType;
import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
+import org.apache.doris.nereids.rules.rewrite.AggInnerJoinToSemiJoin;
import
org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction;
import org.apache.doris.nereids.rules.rewrite.BuildAggForUnion;
import org.apache.doris.nereids.rules.rewrite.CTEInline;
@@ -603,6 +604,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
),
bottomUp(RuleSet.PUSH_DOWN_FILTERS)
),
+ topic("transform inner join to semi join",
+ bottomUp(new AggInnerJoinToSemiJoin())),
topic("infer predicate",
cascadesContext ->
cascadesContext.rewritePlanContainsTypes(
LogicalFilter.class, LogicalJoin.class,
LogicalSetOperation.class
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index f38cb30f674..49a5e5f5c61 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -286,6 +286,7 @@ public enum RuleType {
ELIMINATE_NOT_NULL(RuleTypeClass.REWRITE),
ELIMINATE_UNNECESSARY_PROJECT(RuleTypeClass.REWRITE),
RECORD_PLAN_FOR_MV_PRE_REWRITE(RuleTypeClass.REWRITE),
+ AGG_INNER_JOIN_TO_SEMI_JOIN(RuleTypeClass.REWRITE),
ELIMINATE_OUTER_JOIN(RuleTypeClass.REWRITE),
ELIMINATE_MARK_JOIN(RuleTypeClass.REWRITE),
ELIMINATE_GROUP_BY(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoin.java
new file mode 100644
index 00000000000..7b825026d95
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoin.java
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+
+import java.util.Set;
+
+/**
+ * Rewrite pattern:
+ *
+ * <pre>
+ * Aggregate
+ * Project
+ * InnerJoin(left, right)
+ * </pre>
+ *
+ * into
+ *
+ * <pre>
+ * Aggregate
+ * Project
+ * SemiJoin(left, right)
+ * </pre>
+ *
+ * when the project only references slots from a single join child and the
+ * aggregate simply forwards those slots (i.e. it has no aggregate functions),
+ * so the duplicate rows an inner join may produce are collapsed by the
+ * grouping anyway. The rewritten join becomes LEFT_SEMI or RIGHT_SEMI
+ * depending on which side supplies all projected slots.
+ */
+public class AggInnerJoinToSemiJoin extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalAggregate(logicalProject(logicalJoin()))
+ .when(agg -> patternCheck(agg))
+ .then(agg -> transform(agg))
+ .toRule(RuleType.AGG_INNER_JOIN_TO_SEMI_JOIN);
+ }
+
+ Plan transform(LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>>
agg) {
+ LogicalProject<LogicalJoin<Plan, Plan>> project = agg.child();
+ LogicalJoin<Plan, Plan> join = project.child();
+ Set<Slot> leftOutput = join.left().getOutputSet();
+ Set<Slot> rightOutput = join.right().getOutputSet();
+ Set<Slot> projectSlots = project.getInputSlots();
+ LogicalJoin<Plan, Plan> newJoin;
+ if (leftOutput.containsAll(projectSlots)) {
+ newJoin = join.withJoinType(JoinType.LEFT_SEMI_JOIN);
+ } else if (rightOutput.containsAll(projectSlots)) {
+ newJoin = join.withJoinType(JoinType.RIGHT_SEMI_JOIN);
+ } else {
+ return null;
+ }
+ return agg.withChildren(project.withChildren(newJoin));
+ }
+
+ boolean patternCheck(LogicalAggregate<LogicalProject<LogicalJoin<Plan,
Plan>>> agg) {
+ LogicalProject<LogicalJoin<Plan, Plan>> project = agg.child();
+ LogicalJoin<Plan, Plan> join = project.child();
+ // only a plain (non-mark) inner join qualifies
+ if (join.getJoinType() != JoinType.INNER_JOIN || join.isMarkJoin()) {
+ return false;
+ }
+ // the project must only reference slots from one side (left or right) of the join
+ Set<Slot> leftOutput = join.left().getOutputSet();
+ Set<Slot> rightOutput = join.right().getOutputSet();
+ Set<Slot> projectSlots = project.getInputSlots();
+ if (!leftOutput.containsAll(projectSlots) &&
!rightOutput.containsAll(projectSlots)) {
+ return false;
+ }
+
+ // every aggregate output must be a plain slot, i.e. no aggregate functions
+ for (Expression expr : agg.getOutputExpressions()) {
+ if (!(expr instanceof Slot)) {
+ return false;
+ }
+ }
+ return true;
+
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoinTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoinTest.java
new file mode 100644
index 00000000000..e0cb08b2f32
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoinTest.java
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Test for AggInnerJoinToSemiJoin rule.
+ */
+class AggInnerJoinToSemiJoinTest implements MemoPatternMatchSupported {
+ private LogicalOlapScan scan1;
+ private LogicalOlapScan scan2;
+
+ @BeforeEach
+ void setUp() throws Exception {
+ // reset the context and id generator so that slot ids stay consistent on every run
+ ConnectContext.remove();
+ StatementScopeIdGenerator.clear();
+ scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+ scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+ }
+
+ @Test
+ void testRewriteToLeftSemiJoin() {
+ // Aggregate only references left side columns, should rewrite to LEFT_SEMI_JOIN
+ List<NamedExpression> groupByExprs = scan1.getOutput().stream()
+ .collect(Collectors.toList());
+
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id =
t2.id
+ .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+ .aggGroupUsingIndex(ImmutableList.of(0, 1), groupByExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new AggInnerJoinToSemiJoin())
+ .printlnTree()
+ .matches(logicalAggregate(
+ logicalProject(
+ logicalJoin().when(join ->
join.getJoinType().isLeftSemiJoin()))));
+ }
+
+ @Test
+ void testRewriteToRightSemiJoin() {
+ // Aggregate only references right side columns, should rewrite to
+ // RIGHT_SEMI_JOIN
+ List<NamedExpression> groupByExprs = scan2.getOutput().stream()
+ .collect(Collectors.toList());
+
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id =
t2.id
+ .projectExprs(ImmutableList.copyOf(scan2.getOutput()))
+ .aggGroupUsingIndex(ImmutableList.of(0, 1), groupByExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new AggInnerJoinToSemiJoin())
+ .printlnTree()
+ .matches(logicalAggregate(
+ logicalProject(
+ logicalJoin().when(join ->
join.getJoinType().isRightSemiJoin()))));
+ }
+
+ @Test
+ void testNoRewriteWithAggFunction() {
+ // Aggregate has aggregate function (count), should NOT rewrite
+ List<NamedExpression> outputExprs =
ImmutableList.<NamedExpression>builder()
+ .add(scan1.getOutput().get(0))
+ .add(new Alias(new Count(), "cnt"))
+ .build();
+
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id =
t2.id
+ .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+ .aggGroupUsingIndex(ImmutableList.of(0), outputExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new AggInnerJoinToSemiJoin())
+ .printlnTree()
+ .matches(logicalAggregate(
+ logicalProject(
+ logicalJoin().when(join ->
join.getJoinType().isInnerJoin()))));
+ }
+
+ @Test
+ void testNoRewriteWithBothSideColumns() {
+ // Aggregate references columns from both sides, should NOT rewrite
+ List<NamedExpression> projectExprs =
ImmutableList.<NamedExpression>builder()
+ .add(scan1.getOutput().get(0))
+ .add(scan2.getOutput().get(0))
+ .build();
+ List<NamedExpression> groupByExprs =
ImmutableList.<NamedExpression>builder()
+ .addAll(projectExprs)
+ .build();
+
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id =
t2.id
+ .projectExprs(projectExprs)
+ .aggGroupUsingIndex(ImmutableList.of(0, 1), groupByExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new AggInnerJoinToSemiJoin())
+ .printlnTree()
+ .matches(logicalAggregate(
+ logicalProject(
+ logicalJoin().when(join ->
join.getJoinType().isInnerJoin()))));
+ }
+
+ @Test
+ void testNoRewriteWithLeftOuterJoin() {
+ // Left outer join, should NOT rewrite
+ List<NamedExpression> groupByExprs = scan1.getOutput().stream()
+ .collect(Collectors.toList());
+
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id
= t2.id
+ .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+ .aggGroupUsingIndex(ImmutableList.of(0, 1), groupByExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new AggInnerJoinToSemiJoin())
+ .printlnTree()
+ .matches(logicalAggregate(
+ logicalProject(
+ logicalJoin().when(join ->
join.getJoinType().isLeftOuterJoin()))));
+ }
+
+ @Test
+ void testRewriteWithSimpleProjection() {
+ // Simple case: project left columns then aggregate
+ List<NamedExpression> projectExprs =
ImmutableList.<NamedExpression>of(scan1.getOutput().get(0));
+ List<NamedExpression> groupByExprs =
ImmutableList.<NamedExpression>copyOf(projectExprs);
+
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id =
t2.id
+ .projectExprs(projectExprs)
+ .aggGroupUsingIndex(ImmutableList.of(0), groupByExprs)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new AggInnerJoinToSemiJoin())
+ .printlnTree()
+ .matches(logicalAggregate(
+ logicalProject(
+ logicalJoin().when(join ->
join.getJoinType().isLeftSemiJoin()))));
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]