This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new 8fcce4f591c [improvement](nereids) support extract from disjunction in 
join on condition (#38479) (#43670)
8fcce4f591c is described below

commit 8fcce4f591c2b3ffbf4d67b762c75281a56f6edd
Author: feiniaofeiafei <[email protected]>
AuthorDate: Sat Nov 16 17:15:39 2024 +0800

    [improvement](nereids) support extract from disjunction in join on 
condition (#38479) (#43670)
    
    cherry-pick #38479 to branch-2.1
---
 ...xtractSingleTableExpressionFromDisjunction.java | 54 +++++++++----
 ...ctSingleTableExpressionFromDisjunctionTest.java | 36 +++++++++
 .../extract_from_disjunction_in_join.out           | 94 ++++++++++++++++++++++
 .../push_down_filter_through_window.out            |  0
 .../extract_from_disjunction_in_join.groovy        | 83 +++++++++++++++++++
 .../push_down_filter_through_window.groovy         |  0
 6 files changed, 251 insertions(+), 16 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
index 2f8e1404b71..fe2d7072ef5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
@@ -21,9 +21,11 @@ import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 
@@ -50,24 +52,44 @@ import java.util.Set;
  * 3. In old optimizer, there is `InferFilterRule` generates redundancy 
expressions. Its Nereid counterpart also need
  * `RemoveRedundantExpression`.
  * <p>
- * TODO: This rule just match filter, but it could be applied to inner/cross 
join condition.
  */
-public class ExtractSingleTableExpressionFromDisjunction extends 
OneRewriteRuleFactory {
+public class ExtractSingleTableExpressionFromDisjunction implements 
RewriteRuleFactory {
+    private static final ImmutableSet<JoinType> ALLOW_JOIN_TYPE = 
ImmutableSet.of(JoinType.INNER_JOIN,
+            JoinType.LEFT_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN, 
JoinType.LEFT_SEMI_JOIN, JoinType.RIGHT_SEMI_JOIN,
+            JoinType.LEFT_ANTI_JOIN, JoinType.RIGHT_ANTI_JOIN, 
JoinType.CROSS_JOIN, JoinType.FULL_OUTER_JOIN);
+
     @Override
-    public Rule build() {
-        return logicalFilter().then(filter -> {
-            List<Expression> dependentPredicates = 
extractDependentConjuncts(filter.getConjuncts());
-            if (dependentPredicates.isEmpty()) {
-                return null;
-            }
-            Set<Expression> newPredicates = ImmutableSet.<Expression>builder()
-                    .addAll(filter.getConjuncts())
-                    .addAll(dependentPredicates).build();
-            if (newPredicates.size() == filter.getConjuncts().size()) {
-                return null;
-            }
-            return new LogicalFilter<>(newPredicates, filter.child());
-        }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION);
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                logicalFilter().then(filter -> {
+                    List<Expression> dependentPredicates = 
extractDependentConjuncts(filter.getConjuncts());
+                    if (dependentPredicates.isEmpty()) {
+                        return null;
+                    }
+                    Set<Expression> newPredicates = 
ImmutableSet.<Expression>builder()
+                            .addAll(filter.getConjuncts())
+                            .addAll(dependentPredicates).build();
+                    if (newPredicates.size() == filter.getConjuncts().size()) {
+                        return null;
+                    }
+                    return new LogicalFilter<>(newPredicates, filter.child());
+                
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION),
+                logicalJoin().when(join -> 
ALLOW_JOIN_TYPE.contains(join.getJoinType())).then(join -> {
+                    List<Expression> dependentOtherPredicates = 
extractDependentConjuncts(
+                            ImmutableSet.copyOf(join.getOtherJoinConjuncts()));
+                    if (dependentOtherPredicates.isEmpty()) {
+                        return null;
+                    }
+                    Set<Expression> newOtherPredicates = 
ImmutableSet.<Expression>builder()
+                            .addAll(join.getOtherJoinConjuncts())
+                            .addAll(dependentOtherPredicates).build();
+                    if (newOtherPredicates.size() == 
join.getOtherJoinConjuncts().size()) {
+                        return null;
+                    }
+                    return join.withJoinConjuncts(join.getHashJoinConjuncts(),
+                            ImmutableList.copyOf(newOtherPredicates),
+                            join.getMarkJoinConjuncts(), 
join.getJoinReorderContext());
+                
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION));
     }
 
     private List<Expression> extractDependentConjuncts(Set<Expression> 
conjuncts) {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
index fc55f473ee6..39706d39f2c 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
@@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
@@ -41,6 +42,7 @@ import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
 
+import java.util.List;
 import java.util.Set;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -179,4 +181,38 @@ public class 
ExtractSingleTableExpressionFromDisjunctionTest implements MemoPatt
 
         return conjuncts.size() == 2 && conjuncts.contains(or);
     }
+
+    /**
+     * test extracting a single-table expression from join otherJoinConjuncts
+     * (cid=1 and sage=10) or sgender=1
+     * =>
+     * (sage=10 or sgender=1)
+     */
+    @Test
+    public void testExtract4() {
+        Expression expr = new Or(
+                new And(
+                        new EqualTo(courseCid, new IntegerLiteral(1)),
+                        new EqualTo(studentAge, new IntegerLiteral(10))
+                ),
+                new EqualTo(studentGender, new IntegerLiteral(1))
+        );
+        Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, 
ExpressionUtils.EMPTY_CONDITION, ImmutableList.of(expr),
+                student, course, null);
+        PlanChecker.from(MemoTestUtils.createConnectContext(), join)
+                .applyTopDown(new 
ExtractSingleTableExpressionFromDisjunction())
+                .matchesFromRoot(
+                        logicalJoin()
+                                .when(j -> 
verifySingleTableExpression4(j.getOtherJoinConjuncts()))
+                );
+        Assertions.assertNotNull(studentGender);
+    }
+
+    private boolean verifySingleTableExpression4(List<Expression> conjuncts) {
+        Expression or = new Or(
+                new EqualTo(studentAge, new IntegerLiteral(10)),
+                new EqualTo(studentGender, new IntegerLiteral(1))
+        );
+        return conjuncts.size() == 2 && conjuncts.contains(or);
+    }
 }
diff --git 
a/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out
 
b/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out
new file mode 100644
index 00000000000..9077ecb24b9
--- /dev/null
+++ 
b/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out
@@ -0,0 +1,94 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !left_semi --
+PhysicalResultSink
+--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.b = t2.b)) 
otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !right_semi --
+PhysicalResultSink
+--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.b = t2.b)) 
otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+
+-- !left --
+PhysicalResultSink
+--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) 
otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) 
and a IN (1, 2))
+----PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !right --
+PhysicalResultSink
+--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) 
otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) 
and a IN (8, 9))
+----PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+
+-- !left_anti --
+PhysicalResultSink
+--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t1.b = t2.b)) 
otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) 
and a IN (1, 2))
+----PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !right_anti --
+PhysicalResultSink
+--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t1.b = t2.b)) 
otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) 
and a IN (8, 9))
+----PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+
+-- !inner --
+PhysicalResultSink
+--hashJoin[INNER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 
9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !outer --
+PhysicalResultSink
+--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) 
otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) 
and a IN (1, 2))
+----filter((t1.c = 3))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !left_semi_res --
+1
+2
+
+-- !right_semi_res --
+8
+9
+
+-- !left_res --
+1
+2
+3
+
+-- !right_res --
+\N
+1
+2
+
+-- !left_anti_res --
+3
+
+-- !right_anti_res --
+7
+
+-- !inner_res --
+1
+2
+
+-- !outer_res --
+1
+2
+3
+
diff --git 
a/regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out
 
b/regression-test/data/nereids_rules_p0/push_down_filter/push_down_filter_through_window.out
similarity index 100%
rename from 
regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out
rename to 
regression-test/data/nereids_rules_p0/push_down_filter/push_down_filter_through_window.out
diff --git 
a/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy
 
b/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy
new file mode 100644
index 00000000000..858f39e5e65
--- /dev/null
+++ 
b/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("extract_from_disjunction_in_join") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+    sql "set ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
+    sql "set disable_nereids_rules=PRUNE_EMPTY_PARTITION"
+    sql "set runtime_filter_mode=OFF"
+
+
+    sql "drop table if exists extract_from_disjunction_in_join_t1"
+    sql "drop table if exists extract_from_disjunction_in_join_t2"
+    sql """
+    CREATE TABLE `extract_from_disjunction_in_join_t1` (
+      `a` INT NULL,
+      `b` VARCHAR(10) NULL,
+      `c` INT NULL,
+      `d` INT NULL
+    ) ENGINE=OLAP
+    DUPLICATE KEY(`a`, `b`)
+    DISTRIBUTED BY RANDOM BUCKETS AUTO
+    PROPERTIES (
+    "replication_allocation" = "tag.location.default: 1"
+    );
+    """
+    sql """
+    CREATE TABLE `extract_from_disjunction_in_join_t2` (
+      `a` INT NULL,
+      `b` VARCHAR(10) NULL,
+      `c` INT NULL,
+      `d` INT NULL
+    ) ENGINE=OLAP
+    DUPLICATE KEY(`a`, `b`)
+    DISTRIBUTED BY RANDOM BUCKETS AUTO
+    PROPERTIES (
+    "replication_allocation" = "tag.location.default: 1"
+    );"""
+
+    sql "insert into extract_from_disjunction_in_join_t1 
values(1,'d2',3,5),(2,'d2',3,5),(3,'d2',3,5);"
+    sql "insert into extract_from_disjunction_in_join_t2 
values(7,'d2',2,2),(8,'d2',2,2),(9,'d2',2,2);"
+    qt_left_semi """explain shape plan
+    select * from extract_from_disjunction_in_join_t1 t1 left semi join 
extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || 
t1.a=2 && t2.a=8);"""
+    qt_right_semi """explain shape plan
+    select * from extract_from_disjunction_in_join_t1 t1 right semi join 
extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || 
t1.a=2 && t2.a=8);"""
+    qt_left """explain shape plan
+    select * from extract_from_disjunction_in_join_t1 t1 left join 
extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || 
t1.a=2 && t2.a=8);"""
+    qt_right """explain shape plan
+    select * from extract_from_disjunction_in_join_t1 t1 right join 
extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || 
t1.a=2 && t2.a=8);"""
+    qt_left_anti """explain shape plan
+    select * from extract_from_disjunction_in_join_t1 t1 left anti join 
extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || 
t1.a=2 && t2.a=8);"""
+    qt_right_anti """explain shape plan
+    select * from extract_from_disjunction_in_join_t1 t1 right anti join 
extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || 
t1.a=2 && t2.a=8);"""
+    qt_inner """explain shape plan
+    select * from extract_from_disjunction_in_join_t1 t1 inner join 
extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || 
t1.a=2 && t2.a=8);"""
+    qt_outer """explain shape plan
+    select * from extract_from_disjunction_in_join_t1 t1 full join 
extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || 
t1.a=2 && t2.a=8)
+    where t1.c=3;"""
+
+    qt_left_semi_res "select t1.a from extract_from_disjunction_in_join_t1 t1 
left semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 
&& t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+    qt_right_semi_res "select t2.a from extract_from_disjunction_in_join_t1 t1 
right semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 
&& t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+    qt_left_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left 
join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 
|| t1.a=2 && t2.a=8) order by 1;"
+    qt_right_res "select t1.a from extract_from_disjunction_in_join_t1 t1 
right join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && 
t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+    qt_left_anti_res "select t1.a from extract_from_disjunction_in_join_t1 t1 
left anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 
&& t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+    qt_right_anti_res "select t2.a from extract_from_disjunction_in_join_t1 t1 
right anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 
&& t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+    qt_inner_res "select t1.a from extract_from_disjunction_in_join_t1 t1 
inner join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && 
t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+    qt_outer_res """select t1.a from extract_from_disjunction_in_join_t1 t1 
full join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && 
t1.a=1 || t1.a=2 && t2.a=8)
+    where t1.c=3 order by 1;"""
+}
\ No newline at end of file
diff --git 
a/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy
 
b/regression-test/suites/nereids_rules_p0/push_down_filter/push_down_filter_through_window.groovy
similarity index 100%
rename from 
regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy
rename to 
regression-test/suites/nereids_rules_p0/push_down_filter/push_down_filter_through_window.groovy


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to