This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch tpc_preview
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/tpc_preview by this push:
     new 8d077cc3da7 transform inner join to semi join (#58220)
8d077cc3da7 is described below

commit 8d077cc3da74ea619cda60a48a0e55bb68266e0d
Author: minghong <[email protected]>
AuthorDate: Fri Nov 21 14:12:28 2025 +0800

    transform inner join to semi join (#58220)
    
    ### What problem does this PR solve?
    
    Issue Number: close #xxx
    
    Related PR: #xxx
    
    Problem Summary:
    
    ### Release note
    
    None
    
    ### Check List (For Author)
    
    - Test <!-- At least one of them must be included. -->
        - [ ] Regression test
        - [ ] Unit Test
        - [ ] Manual test (add detailed scripts or steps below)
        - [ ] No need to test or manual test. Explain why:
    - [ ] This is a refactor/code format and no logic has been changed.
            - [ ] Previous test can cover this change.
            - [ ] No code files have been changed.
            - [ ] Other reason <!-- Add your reason?  -->
    
    - Behavior changed:
        - [ ] No.
        - [ ] Yes. <!-- Explain the behavior change -->
    
    - Does this need documentation?
        - [ ] No.
    - [ ] Yes. <!-- Add document PR link here. eg:
    https://github.com/apache/doris-website/pull/1214 -->
    
    ### Check List (For Reviewer who merge this PR)
    
    - [ ] Confirm the release note
    - [ ] Confirm test cases
    - [ ] Confirm document
    - [ ] Add branch pick label <!-- Add branch pick label that this PR
    should merge into -->
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   3 +
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../rules/rewrite/AggInnerJoinToSemiJoin.java      | 105 ++++++++++++
 .../rules/rewrite/AggInnerJoinToSemiJoinTest.java  | 185 +++++++++++++++++++++
 4 files changed, 294 insertions(+)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index f283aaae139..7ab92737551 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -38,6 +38,7 @@ import 
org.apache.doris.nereids.rules.rewrite.AddProjectForJoin;
 import org.apache.doris.nereids.rules.rewrite.AddProjectForUniqueFunction;
 import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType;
 import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
+import org.apache.doris.nereids.rules.rewrite.AggInnerJoinToSemiJoin;
 import 
org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction;
 import org.apache.doris.nereids.rules.rewrite.BuildAggForUnion;
 import org.apache.doris.nereids.rules.rewrite.CTEInline;
@@ -603,6 +604,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
                             ),
                             bottomUp(RuleSet.PUSH_DOWN_FILTERS)
                     ),
+                    topic("transform inner join to semi join",
+                            bottomUp(new AggInnerJoinToSemiJoin())),
                     topic("infer predicate",
                         cascadesContext -> 
cascadesContext.rewritePlanContainsTypes(
                                 LogicalFilter.class, LogicalJoin.class, 
LogicalSetOperation.class
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index f38cb30f674..49a5e5f5c61 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -286,6 +286,7 @@ public enum RuleType {
     ELIMINATE_NOT_NULL(RuleTypeClass.REWRITE),
     ELIMINATE_UNNECESSARY_PROJECT(RuleTypeClass.REWRITE),
     RECORD_PLAN_FOR_MV_PRE_REWRITE(RuleTypeClass.REWRITE),
+    AGG_INNER_JOIN_TO_SEMI_JOIN(RuleTypeClass.REWRITE),
     ELIMINATE_OUTER_JOIN(RuleTypeClass.REWRITE),
     ELIMINATE_MARK_JOIN(RuleTypeClass.REWRITE),
     ELIMINATE_GROUP_BY(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoin.java
new file mode 100644
index 00000000000..7b825026d95
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoin.java
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+
+import java.util.Set;
+
+/**
+ * Rewrite pattern:
+ *
+ * <pre>
+ *   Aggregate
+ *     Project
+ *       InnerJoin(left, right)
+ * </pre>
+ *
+ * into
+ *
+ * <pre>
+ *   Aggregate
+ *     Project
+ *       SemiJoin(left, right)
+ * </pre>
+ *
+ * when the project only references slots from a single join child and the
+ * aggregate simply forwards those slots (i.e. no real aggregation expressions).
+ *
+ * <p>The rewritten join becomes LEFT_SEMI or RIGHT_SEMI depending on which side
+ * supplies all projected slots.
+ */
+public class AggInnerJoinToSemiJoin extends OneRewriteRuleFactory {
+    @Override
+    public Rule build() {
+        return logicalAggregate(logicalProject(logicalJoin()))
+                .when(agg -> patternCheck(agg))
+                .then(agg -> transform(agg))
+                .toRule(RuleType.AGG_INNER_JOIN_TO_SEMI_JOIN);
+    }
+
+    Plan transform(LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> 
agg) {
+        LogicalProject<LogicalJoin<Plan, Plan>> project = agg.child();
+        LogicalJoin<Plan, Plan> join = project.child();
+        Set<Slot> leftOutput = join.left().getOutputSet();
+        Set<Slot> rightOutput = join.right().getOutputSet();
+        Set<Slot> projectSlots = project.getInputSlots();
+        LogicalJoin<Plan, Plan> newJoin;
+        if (leftOutput.containsAll(projectSlots)) {
+            newJoin = join.withJoinType(JoinType.LEFT_SEMI_JOIN);
+        } else if (rightOutput.containsAll(projectSlots)) {
+            newJoin = join.withJoinType(JoinType.RIGHT_SEMI_JOIN);
+        } else {
+            return null;
+        }
+        return agg.withChildren(project.withChildren(newJoin));
+    }
+
+    boolean patternCheck(LogicalAggregate<LogicalProject<LogicalJoin<Plan, 
Plan>>> agg) {
+        LogicalProject<LogicalJoin<Plan, Plan>> project = agg.child();
+        LogicalJoin<Plan, Plan> join = project.child();
+        // this is an inner join
+        if (join.getJoinType() != JoinType.INNER_JOIN || join.isMarkJoin()) {
+            return false;
+        }
+        // join only output left/right slots
+        Set<Slot> leftOutput = join.left().getOutputSet();
+        Set<Slot> rightOutput = join.right().getOutputSet();
+        Set<Slot> projectSlots = project.getInputSlots();
+        if (!leftOutput.containsAll(projectSlots) && 
!rightOutput.containsAll(projectSlots)) {
+            return false;
+        }
+
+        // no aggregate functions
+        for (Expression expr : agg.getOutputExpressions()) {
+            if (!(expr instanceof Slot)) {
+                return false;
+            }
+        }
+        return true;
+
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoinTest.java
new file mode 100644
index 00000000000..e0cb08b2f32
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggInnerJoinToSemiJoinTest.java
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Test for AggInnerJoinToSemiJoin rule.
+ */
+class AggInnerJoinToSemiJoinTest implements MemoPatternMatchSupported {
+    private LogicalOlapScan scan1;
+    private LogicalOlapScan scan2;
+
+    @BeforeEach
+    void setUp() throws Exception {
+        // clear id so that slot id keep consistent every running
+        ConnectContext.remove();
+        StatementScopeIdGenerator.clear();
+        scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+        scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+    }
+
+    @Test
+    void testRewriteToLeftSemiJoin() {
+        // Aggregate only references left side columns, should rewrite to 
LEFT_SEMI_JOIN
+        List<NamedExpression> groupByExprs = scan1.getOutput().stream()
+                .collect(Collectors.toList());
+
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = 
t2.id
+                .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+                .aggGroupUsingIndex(ImmutableList.of(0, 1), groupByExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new AggInnerJoinToSemiJoin())
+                .printlnTree()
+                .matches(logicalAggregate(
+                        logicalProject(
+                                logicalJoin().when(join -> 
join.getJoinType().isLeftSemiJoin()))));
+    }
+
+    @Test
+    void testRewriteToRightSemiJoin() {
+        // Aggregate only references right side columns, should rewrite to
+        // RIGHT_SEMI_JOIN
+        List<NamedExpression> groupByExprs = scan2.getOutput().stream()
+                .collect(Collectors.toList());
+
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = 
t2.id
+                .projectExprs(ImmutableList.copyOf(scan2.getOutput()))
+                .aggGroupUsingIndex(ImmutableList.of(0, 1), groupByExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new AggInnerJoinToSemiJoin())
+                .printlnTree()
+                .matches(logicalAggregate(
+                        logicalProject(
+                                logicalJoin().when(join -> 
join.getJoinType().isRightSemiJoin()))));
+    }
+
+    @Test
+    void testNoRewriteWithAggFunction() {
+        // Aggregate has aggregate function (count), should NOT rewrite
+        List<NamedExpression> outputExprs = 
ImmutableList.<NamedExpression>builder()
+                .add(scan1.getOutput().get(0))
+                .add(new Alias(new Count(), "cnt"))
+                .build();
+
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = 
t2.id
+                .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+                .aggGroupUsingIndex(ImmutableList.of(0), outputExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new AggInnerJoinToSemiJoin())
+                .printlnTree()
+                .matches(logicalAggregate(
+                        logicalProject(
+                                logicalJoin().when(join -> 
join.getJoinType().isInnerJoin()))));
+    }
+
+    @Test
+    void testNoRewriteWithBothSideColumns() {
+        // Aggregate references columns from both sides, should NOT rewrite
+        List<NamedExpression> projectExprs = 
ImmutableList.<NamedExpression>builder()
+                .add(scan1.getOutput().get(0))
+                .add(scan2.getOutput().get(0))
+                .build();
+        List<NamedExpression> groupByExprs = 
ImmutableList.<NamedExpression>builder()
+                .addAll(projectExprs)
+                .build();
+
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = 
t2.id
+                .projectExprs(projectExprs)
+                .aggGroupUsingIndex(ImmutableList.of(0, 1), groupByExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new AggInnerJoinToSemiJoin())
+                .printlnTree()
+                .matches(logicalAggregate(
+                        logicalProject(
+                                logicalJoin().when(join -> 
join.getJoinType().isInnerJoin()))));
+    }
+
+    @Test
+    void testNoRewriteWithLeftOuterJoin() {
+        // Left outer join, should NOT rewrite
+        List<NamedExpression> groupByExprs = scan1.getOutput().stream()
+                .collect(Collectors.toList());
+
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id 
= t2.id
+                .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+                .aggGroupUsingIndex(ImmutableList.of(0, 1), groupByExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new AggInnerJoinToSemiJoin())
+                .printlnTree()
+                .matches(logicalAggregate(
+                        logicalProject(
+                                logicalJoin().when(join -> 
join.getJoinType().isLeftOuterJoin()))));
+    }
+
+    @Test
+    void testRewriteWithSimpleProjection() {
+        // Simple case: project left columns then aggregate
+        List<NamedExpression> projectExprs = 
ImmutableList.<NamedExpression>of(scan1.getOutput().get(0));
+        List<NamedExpression> groupByExprs = 
ImmutableList.<NamedExpression>copyOf(projectExprs);
+
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = 
t2.id
+                .projectExprs(projectExprs)
+                .aggGroupUsingIndex(ImmutableList.of(0), groupByExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new AggInnerJoinToSemiJoin())
+                .printlnTree()
+                .matches(logicalAggregate(
+                        logicalProject(
+                                logicalJoin().when(join -> 
join.getJoinType().isLeftSemiJoin()))));
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to