This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 32fa9e09f4 [feature](Nereids): enable OuterJoinAssoc (#19111)
32fa9e09f4 is described below

commit 32fa9e09f4167322378a92f2ece95e235f8f7f6a
Author: jakevin <[email protected]>
AuthorDate: Thu Apr 27 09:50:15 2023 +0800

    [feature](Nereids): enable OuterJoinAssoc (#19111)
---
 .../org/apache/doris/nereids/rules/RuleSet.java    |  2 ++
 .../rules/exploration/join/OuterJoinAssoc.java     | 24 +++++++++++++----
 .../exploration/join/OuterJoinAssocProject.java    | 25 +++++++++++++++++-
 .../ExtractAndNormalizeWindowExpression.java       |  3 +--
 .../rules/rewrite/logical/NormalizeAggregate.java  | 30 ++++++++++------------
 .../rules/exploration/join/OuterJoinAssocTest.java | 16 ++++++++++++
 6 files changed, 76 insertions(+), 24 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
index aa30417fc2..b696b94f59 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
@@ -32,6 +32,7 @@ import 
org.apache.doris.nereids.rules.exploration.join.JoinExchange;
 import org.apache.doris.nereids.rules.exploration.join.JoinExchangeBothProject;
 import 
org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTranspose;
 import 
org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTransposeProject;
+import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssoc;
 import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscom;
 import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
 import 
org.apache.doris.nereids.rules.exploration.join.PushdownProjectThroughInnerJoin;
@@ -162,6 +163,7 @@ public class RuleSet {
             .add(InnerJoinRightAssociateProject.INSTANCE)
             .add(JoinExchange.INSTANCE)
             .add(JoinExchangeBothProject.INSTANCE)
+            .add(OuterJoinAssoc.INSTANCE)
             .build();
 
     public List<Rule> getOtherReorderRules() {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java
index 6bb35baa88..2080cfce93 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java
@@ -21,12 +21,14 @@ import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.Utils;
 
 import com.google.common.collect.ImmutableSet;
@@ -58,17 +60,29 @@ public class OuterJoinAssoc extends 
OneExplorationRuleFactory {
                 .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, 
topJoin.left()))
                 .when(topJoin -> checkCondition(topJoin, 
topJoin.left().left().getOutputSet()))
                 .whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
-                .then(topJoin -> {
+                .thenApply(ctx -> {
+                    LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> 
topJoin = ctx.root;
                     LogicalJoin<GroupPlan, GroupPlan> bottomJoin = 
topJoin.left();
                     GroupPlan a = bottomJoin.left();
                     GroupPlan b = bottomJoin.right();
                     GroupPlan c = topJoin.right();
 
-                    /* TODO:
-                     * p23 need to reject nulls on A(e2) (Eqv. 1)
-                     * see paper `On the Correct and Complete Enumeration of 
the Core Search Space`.
-                     * But because we have added eliminate_outer_rule, we 
don't need to consider this.
+                    /*
+                     * Paper `On the Correct and Complete Enumeration of the 
Core Search Space`.
+                     * p23 needs to reject nulls on A(e2) (Eqv. 1).
+                     * It means that when a slot is null, the condition must return 
false or unknown.
                      */
+                    if (bottomJoin.getJoinType().isLeftOuterJoin() && 
topJoin.getJoinType().isLeftOuterJoin()) {
+                        Set<Slot> conditionSlot = topJoin.getConditionSlot();
+                        Set<Expression> on = ImmutableSet.<Expression>builder()
+                                .addAll(topJoin.getHashJoinConjuncts())
+                                
.addAll(topJoin.getOtherJoinConjuncts()).build();
+                        Set<Slot> notNullSlots = 
ExpressionUtils.inferNotNullSlots(on,
+                                ctx.cascadesContext);
+                        if (!conditionSlot.equals(notNullSlots)) {
+                            return null;
+                        }
+                    }
 
                     LogicalJoin newBottomJoin = 
topJoin.withChildrenNoContext(b, c);
                     
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java
index 308e1db583..efd9e14faf 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java
@@ -23,13 +23,18 @@ import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.exploration.CBOUtils;
 import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
 import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.Utils;
 
+import com.google.common.collect.ImmutableSet;
+
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -60,7 +65,8 @@ public class OuterJoinAssocProject extends 
OneExplorationRuleFactory {
                 .whenNot(join -> join.isMarkJoin() || 
join.left().child().isMarkJoin())
                 .when(join -> OuterJoinAssoc.checkCondition(join, 
join.left().child().left().getOutputSet()))
                 .when(join -> join.left().isAllSlots())
-                .then(topJoin -> {
+                .thenApply(ctx -> {
+                    LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, 
GroupPlan>>, GroupPlan> topJoin = ctx.root;
                     /* ********** init ********** */
                     List<NamedExpression> projects = 
topJoin.left().getProjects();
                     LogicalJoin<GroupPlan, GroupPlan> bottomJoin = 
topJoin.left().child();
@@ -69,6 +75,23 @@ public class OuterJoinAssocProject extends 
OneExplorationRuleFactory {
                     GroupPlan c = topJoin.right();
                     Set<ExprId> aOutputExprIds = a.getOutputExprIdSet();
 
+                    /*
+                     * Paper `On the Correct and Complete Enumeration of the 
Core Search Space`.
+                     * p23 needs to reject nulls on A(e2) (Eqv. 1).
+                     * It means that when a slot is null, the condition must return 
false or unknown.
+                     */
+                    if (bottomJoin.getJoinType().isLeftOuterJoin() && 
topJoin.getJoinType().isLeftOuterJoin()) {
+                        Set<Slot> conditionSlot = topJoin.getConditionSlot();
+                        Set<Expression> on = ImmutableSet.<Expression>builder()
+                                .addAll(topJoin.getHashJoinConjuncts())
+                                
.addAll(topJoin.getOtherJoinConjuncts()).build();
+                        Set<Slot> notNullSlots = 
ExpressionUtils.inferNotNullSlots(on,
+                                ctx.cascadesContext);
+                        if (!conditionSlot.equals(notNullSlots)) {
+                            return null;
+                        }
+                    }
+
                     /* ********** Split projects ********** */
                     Map<Boolean, List<NamedExpression>> map = 
CBOUtils.splitProject(projects, aOutputExprIds);
                     List<NamedExpression> aProjects = map.get(true);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
index 718d6acc6e..9282ef3825 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
@@ -33,7 +33,6 @@ import org.apache.doris.nereids.util.ExpressionUtils;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 
 import java.util.List;
@@ -85,7 +84,7 @@ public class ExtractAndNormalizeWindowExpression extends 
OneRewriteRuleFactory i
             Set<NamedExpression> normalizedWindowWithAlias = 
ctxForWindows.pushDownToNamedExpression(normalizedWindows);
             // only need normalized windowExpressions
             LogicalWindow normalizedLogicalWindow =
-                    new 
LogicalWindow(Lists.newArrayList(normalizedWindowWithAlias), normalizedChild);
+                    new 
LogicalWindow<>(ImmutableList.copyOf(normalizedWindowWithAlias), 
normalizedChild);
 
             // 3. handle top projects
             List<NamedExpression> topProjects = 
ctxForWindows.normalizeToUseSlotRef(normalizedOutputs1);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
index cf802245ca..7938cf8769 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -56,35 +56,38 @@ import java.util.stream.Stream;
  *            Alias(SUM(v1#3 + 1))#7, Alias(SUM(v1#3) + 1)#8])
  * </pre>
  * After rule:
+ * <pre>
 * Project(k1#1, Alias(SR#9)#4, Alias(k1#1 + 1)#5, Alias(SR#10)#6, 
Alias(SR#11)#7, Alias(SR#10 + 1)#8)
  * +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, 
Alias(SUM(v1#3 + 1))#11])
  *   +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
- * <p>
- *
+ * </pre>
  * Note: window function will be moved to upper project
  * all agg functions except the top agg should be pushed to Aggregate node.
  * example 1:
+ * <pre>
  *    select min(x), sum(x) over () ...
  * the 'sum(x)' is top agg of window function, it should be moved to upper 
project
  * plan:
  *    project(sum(x) over())
  *        Aggregate(min(x), x)
- *
+ * </pre>
  * example 2:
+ * <pre>
  *    select min(x), avg(sum(x)) over() ...
  * the 'sum(x)' should be moved to Aggregate
  * plan:
  *    project(avg(y) over())
  *         Aggregate(min(x), sum(x) as y)
+ * </pre>
  * example 3:
+ * <pre>
  *    select sum(x+1), x+1, sum(x+1) over() ...
  * window function should use x instead of x+1
  * plan:
  *    project(sum(x+1) over())
  *        Agg(sum(y), x)
  *            project(x+1 as y)
- *
- *
+ * </pre>
  * More example could get from UT {NormalizeAggregateTest}
  */
 public class NormalizeAggregate extends OneRewriteRuleFactory implements 
NormalizeToSlot {
@@ -138,8 +141,8 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
             // some expression on the aggregate functions, e.g. `sum(value) + 
1`, we should replace
             // the sum(value) to slot and move the `slot + 1` to the upper 
project later.
             List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
-                    aggregate.getOutputExpressions().stream(),
-                    aliasOfAggFunInWindowUsedAsAggOutput.stream())
+                            aggregate.getOutputExpressions().stream(),
+                            aliasOfAggFunInWindowUsedAsAggOutput.stream())
                     .map(expr -> groupByAndArgumentToSlotContext
                             .normalizeToUseSlotRefUp(expr, 
WindowExpression.class::isInstance))
                     .collect(Collectors.toList());
@@ -198,19 +201,14 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
         Set<AggregateFunction> aggregateFunctions = 
collectNonWindowedAggregateFunctions(
                 aggregate.getOutputExpressions());
 
-        ImmutableSet<Expression> argumentsOfAggregateFunction = 
aggregateFunctions.stream()
-                .flatMap(function -> function.getArguments().stream().map(arg 
-> {
-                    if (arg instanceof OrderExpression) {
-                        return arg.child(0);
-                    } else {
-                        return arg;
-                    }
-                }))
+        Set<Expression> argumentsOfAggregateFunction = 
aggregateFunctions.stream()
+                .flatMap(function -> function.getArguments().stream()
+                        .map(expr -> expr instanceof OrderExpression ? 
expr.child(0) : expr))
                 .collect(ImmutableSet.toImmutableSet());
 
         Set<Expression> windowFunctionKeys = 
collectWindowFunctionKeys(aggregate.getOutputExpressions());
 
-        ImmutableSet<Expression> needPushDown = 
ImmutableSet.<Expression>builder()
+        Set<Expression> needPushDown = ImmutableSet.<Expression>builder()
                 // group by should be pushed down, e.g. group by (k + 1),
                 // we should push down the `k + 1` to the bottom plan
                 .addAll(groupingByExpr)
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java
index ad98b7650f..c3beb8fc11 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.rules.exploration.join;
 
 import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.IsNull;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@@ -28,6 +29,8 @@ import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.Objects;
@@ -78,4 +81,17 @@ class OuterJoinAssocTest implements 
MemoPatternMatchSupported {
                         ).when(top -> 
Objects.equals(top.getHashJoinConjuncts().toString(), "[(id#0 = id#2)]"))
                 );
     }
+
+    @Test
+    public void rejectNull() {
+        IsNull isNull = new IsNull(scan3.getOutput().get(0));
+        LogicalPlan join = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id 
= t2.id
+                .join(scan3, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(), 
ImmutableList.of(isNull)) // t3.id is null
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), join)
+                .applyExploration(OuterJoinAssoc.INSTANCE.build())
+                .checkMemo(memo -> Assertions.assertEquals(1, 
memo.getRoot().getLogicalExpressions().size()));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to