This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 62020018bcb [opt](nereids) support extract join multiple tables (#51569)
62020018bcb is described below

commit 62020018bcbc6e01ac076fbff6ca1e0828c603ca
Author: yujun <[email protected]>
AuthorDate: Wed Jun 11 22:12:55 2025 +0800

    [opt](nereids) support extract join multiple tables (#51569)
    
    ### What problem does this PR solve?
    
    The rule ExtractSingleTableExpressionFromDisjunction extracts, for every
    single table, an expression that references only that table from the
    predicates of LogicalFilter and LogicalJoin. This is not enough for joins:
    a join should also support extracting expressions over multiple tables,
    i.e. one expression over the left-side tables and one expression over the
    right-side tables.
---
 ...xtractSingleTableExpressionFromDisjunction.java | 45 ++++++++----
 .../nereids/trees/expressions/NamedExpression.java |  4 ++
 ...ctSingleTableExpressionFromDisjunctionTest.java | 81 ++++++++++++++++++++++
 3 files changed, 115 insertions(+), 15 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
index c3c64c9e076..d56d60bc5e3 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
@@ -65,7 +65,8 @@ public class ExtractSingleTableExpressionFromDisjunction 
implements RewriteRuleF
     public List<Rule> buildRules() {
         return ImmutableList.of(
                 logicalFilter().then(filter -> {
-                    List<Expression> dependentPredicates = 
extractDependentConjuncts(filter.getConjuncts());
+                    List<Expression> dependentPredicates = 
extractDependentConjuncts(filter.getConjuncts(),
+                            Lists.newArrayList());
                     if (dependentPredicates.isEmpty()) {
                         return null;
                     }
@@ -78,8 +79,14 @@ public class ExtractSingleTableExpressionFromDisjunction 
implements RewriteRuleF
                     return new LogicalFilter<>(newPredicates, filter.child());
                 
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION),
                 logicalJoin().when(join -> 
ALLOW_JOIN_TYPE.contains(join.getJoinType())).then(join -> {
+                    List<Set<String>> qualifierBatches = Lists.newArrayList();
+                    // for join, also extract multiple tables: left tables and 
right tables
+                    
qualifierBatches.add(join.left().getOutputSet().stream().map(Slot::getJoinQualifier)
+                            .collect(Collectors.toSet()));
+                    
qualifierBatches.add(join.right().getOutputSet().stream().map(Slot::getJoinQualifier)
+                            .collect(Collectors.toSet()));
                     List<Expression> dependentOtherPredicates = 
extractDependentConjuncts(
-                            ImmutableSet.copyOf(join.getOtherJoinConjuncts()));
+                            ImmutableSet.copyOf(join.getOtherJoinConjuncts()), 
qualifierBatches);
                     if (dependentOtherPredicates.isEmpty()) {
                         return null;
                     }
@@ -95,7 +102,7 @@ public class ExtractSingleTableExpressionFromDisjunction 
implements RewriteRuleF
                 
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION));
     }
 
-    private List<Expression> extractDependentConjuncts(Set<Expression> 
conjuncts) {
+    private List<Expression> extractDependentConjuncts(Set<Expression> 
conjuncts, List<Set<String>> qualifierBatches) {
         List<Expression> dependentPredicates = Lists.newArrayList();
         for (Expression conjunct : conjuncts) {
             // conjunct=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
@@ -112,11 +119,25 @@ public class ExtractSingleTableExpressionFromDisjunction 
implements RewriteRuleF
             Set<String> qualifiers = disjuncts.get(0).getInputSlots().stream()
                     .map(slot -> String.join(".", slot.getQualifier()))
                     .collect(Collectors.toCollection(Sets::newLinkedHashSet));
+            List<Set<String>> includeSingleQualifierBatches = 
Lists.newArrayListWithExpectedSize(
+                    qualifiers.size() + qualifierBatches.size());
+            // extract single table's expression
             for (String qualifier : qualifiers) {
+                includeSingleQualifierBatches.add(ImmutableSet.of(qualifier));
+            }
+            // for join, extract left tables and right tables
+            for (Set<String> batch : qualifierBatches) {
+                Set<String> newBatch = 
batch.stream().filter(qualifiers::contains).collect(Collectors.toSet());
+                // if newBatch.size == 1, then it had put into 
includeSingleQualifierBatches
+                if (newBatch.size() > 1) {
+                    includeSingleQualifierBatches.add(newBatch);
+                }
+            }
+            for (Set<String> batch : includeSingleQualifierBatches) {
                 List<Expression> extractForAll = Lists.newArrayList();
                 boolean success = true;
                 for (Expression expr : disjuncts) {
-                    Optional<Expression> extracted = 
extractSingleTableExpression(expr, qualifier);
+                    Optional<Expression> extracted = 
extractMultipleTableExpression(expr, batch);
                     if (!extracted.isPresent()) {
                         // extract failed
                         success = false;
@@ -136,7 +157,7 @@ public class ExtractSingleTableExpressionFromDisjunction 
implements RewriteRuleF
     // extract some conjucts from expr, all slots of the extracted conjunct 
comes from the table referred by qualifier.
     // example: expr=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'), 
qualifier="n1."
     // output: n1.n_name = 'FRANCE'
-    private Optional<Expression> extractSingleTableExpression(Expression expr, 
String qualifier) {
+    private Optional<Expression> extractMultipleTableExpression(Expression 
expr, Set<String> qualifiers) {
         // suppose the qualifier is table T, then the process steps are as 
follow:
         // 1. split the expression into conjunctions: c1 and c2 and c3 and ...
         // 2. for each conjunction ci, suppose its extract is Ei:
@@ -158,14 +179,14 @@ public class ExtractSingleTableExpressionFromDisjunction 
implements RewriteRuleF
         List<Expression> output = Lists.newArrayList();
         List<Expression> conjuncts = ExpressionUtils.extractConjunction(expr);
         for (Expression conjunct : conjuncts) {
-            if (isSingleTableExpression(conjunct, qualifier)) {
+            if (isTableExpression(conjunct, qualifiers)) {
                 output.add(conjunct);
             } else if (conjunct instanceof Or) {
                 List<Expression> disjuncts = 
ExpressionUtils.extractDisjunction(conjunct);
                 List<Expression> extracted = 
Lists.newArrayListWithExpectedSize(disjuncts.size());
                 boolean success = true;
                 for (Expression disjunct : disjuncts) {
-                    Optional<Expression> extractedDisjunct = 
extractSingleTableExpression(disjunct, qualifier);
+                    Optional<Expression> extractedDisjunct = 
extractMultipleTableExpression(disjunct, qualifiers);
                     if (extractedDisjunct.isPresent()) {
                         
extracted.addAll(ExpressionUtils.extractDisjunction(extractedDisjunct.get()));
                     } else {
@@ -186,14 +207,8 @@ public class ExtractSingleTableExpressionFromDisjunction 
implements RewriteRuleF
         }
     }
 
-    private boolean isSingleTableExpression(Expression expr, String qualifier) 
{
+    private boolean isTableExpression(Expression expr, Set<String> qualifiers) 
{
         //TODO: cache getSlotQualifierAsString() result.
-        for (Slot slot : expr.getInputSlots()) {
-            String slotQualifier = String.join(".", slot.getQualifier());
-            if (!slotQualifier.equals(qualifier)) {
-                return false;
-            }
-        }
-        return true;
+        return expr.getInputSlots().stream().allMatch(slot -> 
qualifiers.contains(slot.getJoinQualifier()));
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NamedExpression.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NamedExpression.java
index d03669234cd..e363296fd72 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NamedExpression.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NamedExpression.java
@@ -48,6 +48,10 @@ public abstract class NamedExpression extends Expression {
         throw new UnboundException("qualifier");
     }
 
+    public String getJoinQualifier() {
+        return String.join(".", getQualifier());
+    }
+
     /**
      * Get qualified name of NamedExpression.
      *
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
index 27901e2db9f..d099315ed52 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.And;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -24,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan;
 import org.apache.doris.nereids.trees.expressions.LessThan;
 import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
 import org.apache.doris.nereids.trees.plans.JoinType;
@@ -45,22 +47,29 @@ import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Set;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
 public class ExtractSingleTableExpressionFromDisjunctionTest implements 
MemoPatternMatchSupported {
     Plan student;
+    Plan score;
     Plan course;
+    Plan salary;
     SlotReference courseCid;
     SlotReference courseName;
     SlotReference studentAge;
     SlotReference studentGender;
+    SlotReference scoreSid;
+    SlotReference salaryId;
 
     @BeforeAll
     public final void beforeAll() {
         student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
PlanConstructor.student, ImmutableList.of(""));
+        score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
PlanConstructor.score, ImmutableList.of(""));
         course = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
PlanConstructor.course, ImmutableList.of(""));
+        salary = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
PlanConstructor.salary, ImmutableList.of(""));
         //select *
         //from student join course
         //where (course.cid=1 and student.age=10) or (student.gender = 0 and 
course.name='abc')
@@ -68,6 +77,8 @@ public class ExtractSingleTableExpressionFromDisjunctionTest 
implements MemoPatt
         courseName = (SlotReference) course.getOutput().get(1);
         studentAge = (SlotReference) student.getOutput().get(3);
         studentGender = (SlotReference) student.getOutput().get(1);
+        scoreSid = (SlotReference) score.getOutput().get(0);
+        salaryId = (SlotReference) salary.getOutput().get(0);
     }
     /**
      *(cid=1 and sage=10) or (sgender=1 and cname='abc')
@@ -265,4 +276,74 @@ public class 
ExtractSingleTableExpressionFromDisjunctionTest implements MemoPatt
 
         return conjuncts.size() == 3 && conjuncts.contains(or1) && 
conjuncts.contains(or2);
     }
+
+    @Test
+    void testExtractMultipleTables() {
+        Expression expr = new Or(
+                ExpressionUtils.and(
+                        new GreaterThan(studentAge, new IntegerLiteral(1)),
+                        new GreaterThan(courseCid, new IntegerLiteral(1)),
+                        new GreaterThan(scoreSid, new IntegerLiteral(1)),
+                        new GreaterThan(salaryId, new IntegerLiteral(1)),
+                        new EqualTo(new Add(studentAge, courseCid), new 
BigIntLiteral(100L)),
+                        new EqualTo(new Add(scoreSid, salaryId), new 
BigIntLiteral(100L))
+                ),
+                ExpressionUtils.and(
+                        new GreaterThan(studentAge, new IntegerLiteral(2)),
+                        new GreaterThan(courseCid, new IntegerLiteral(2)),
+                        new GreaterThan(scoreSid, new IntegerLiteral(2)),
+                        new GreaterThan(salaryId, new IntegerLiteral(2)),
+                        new EqualTo(new Add(studentAge, courseCid), new 
BigIntLiteral(200L)),
+                        new EqualTo(new Add(scoreSid, salaryId), new 
BigIntLiteral(200L))
+                )
+        );
+        Plan left = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course, 
null);
+        Plan right = new LogicalJoin<>(JoinType.CROSS_JOIN, score, salary, 
null);
+        Plan root = new LogicalJoin<>(JoinType.INNER_JOIN, 
ExpressionUtils.EMPTY_CONDITION,
+                Collections.singletonList(expr), left, right, null);
+
+        List<Expression> expectJoinConjuncts = Arrays.asList(
+                // origin expression
+                expr,
+
+                // four single table expression
+                new Or(new GreaterThan(studentAge, new IntegerLiteral(1)),
+                        new GreaterThan(studentAge, new IntegerLiteral(2))),
+                new Or(new GreaterThan(courseCid, new IntegerLiteral(1)),
+                        new GreaterThan(courseCid, new IntegerLiteral(2))),
+                new Or(new GreaterThan(scoreSid, new IntegerLiteral(1)),
+                        new GreaterThan(scoreSid, new IntegerLiteral(2))),
+                new Or(new GreaterThan(salaryId, new IntegerLiteral(1)),
+                        new GreaterThan(salaryId, new IntegerLiteral(2))),
+
+                // left tables
+                new Or(
+                        ExpressionUtils.and(
+                                new GreaterThan(studentAge, new 
IntegerLiteral(1)),
+                                new GreaterThan(courseCid, new 
IntegerLiteral(1)),
+                                new EqualTo(new Add(studentAge, courseCid), 
new BigIntLiteral(100L))),
+                        ExpressionUtils.and(
+                                new GreaterThan(studentAge, new 
IntegerLiteral(2)),
+                                new GreaterThan(courseCid, new 
IntegerLiteral(2)),
+                                new EqualTo(new Add(studentAge, courseCid), 
new BigIntLiteral(200L)))),
+
+                // right tables
+                new Or(
+                        ExpressionUtils.and(
+                                new GreaterThan(scoreSid, new 
IntegerLiteral(1)),
+                                new GreaterThan(salaryId, new 
IntegerLiteral(1)),
+                                new EqualTo(new Add(scoreSid, salaryId), new 
BigIntLiteral(100L))),
+                        ExpressionUtils.and(
+                                new GreaterThan(scoreSid, new 
IntegerLiteral(2)),
+                                new GreaterThan(salaryId, new 
IntegerLiteral(2)),
+                                new EqualTo(new Add(scoreSid, salaryId), new 
BigIntLiteral(200L))))
+        );
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new 
ExtractSingleTableExpressionFromDisjunction())
+                .matchesFromRoot(
+                        logicalJoin()
+                                .when(join -> 
expectJoinConjuncts.equals(join.getOtherJoinConjuncts()))
+                );
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to