This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit ad1c19bd65644bc19392eef6ebc2e845b342fd37
Author: jakevin <[email protected]>
AuthorDate: Mon Jan 22 12:23:50 2024 +0800

    [refactor](Nereids): Eager Aggregation unify pushdown agg function (#30142)
---
 .../rules/rewrite/PushDownMinMaxThroughJoin.java   |  17 ++-
 .../rules/rewrite/PushDownSumThroughJoin.java      |   4 +-
 .../rewrite/PushDownSumThroughJoinOneSide.java     | 117 +--------------------
 3 files changed, 19 insertions(+), 119 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java
index 48ded00defe..3057f1eafc4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java
@@ -81,7 +81,7 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
                                 return null;
                             }
                             LogicalAggregate<LogicalJoin<Plan, Plan>> agg = 
ctx.root;
-                            return pushMinMax(agg, agg.child(), 
ImmutableList.of());
+                            return pushMinMaxSum(agg, agg.child(), 
ImmutableList.of());
                         })
                         .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN),
                 logicalAggregate(logicalProject(innerLogicalJoin()))
@@ -102,13 +102,16 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
                                 return null;
                             }
                             LogicalAggregate<LogicalProject<LogicalJoin<Plan, 
Plan>>> agg = ctx.root;
-                            return pushMinMax(agg, agg.child().child(), 
agg.child().getProjects());
+                            return pushMinMaxSum(agg, agg.child().child(), 
agg.child().getProjects());
                         })
                         .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN)
         );
     }
 
-    private LogicalAggregate<Plan> pushMinMax(LogicalAggregate<? extends Plan> 
agg,
+    /**
+     * Push down Min/Max/Sum through join.
+     */
+    public static LogicalAggregate<Plan> pushMinMaxSum(LogicalAggregate<? 
extends Plan> agg,
             LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
         List<Slot> leftOutput = join.left().getOutput();
         List<Slot> rightOutput = join.right().getOutput();
@@ -125,6 +128,9 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
                 throw new IllegalStateException("Slot " + slot + " not found 
in join output");
             }
         }
+        if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) {
+            return null;
+        }
 
         Set<Slot> leftGroupBy = new HashSet<>();
         Set<Slot> rightGroupBy = new HashSet<>();
@@ -177,6 +183,11 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
         Preconditions.checkState(left != join.left() || right != join.right());
         Plan newJoin = join.withChildren(left, right);
 
+        // top agg
+        // replace
+        // min(x) -> min(min#)
+        // max(x) -> max(max#)
+        // sum(x) -> sum(sum#)
         List<NamedExpression> newOutputExprs = new ArrayList<>();
         for (NamedExpression ne : agg.getOutputExpressions()) {
             if (ne instanceof Alias && ((Alias) ne).child() instanceof 
AggregateFunction) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java
index 91cb2a6050b..e8987e670a5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java
@@ -53,12 +53,12 @@ import java.util.Set;
  * |    *
  * (x)
  * ->
- * aggregate: Sum(min1)
+ * aggregate: Sum(sum1)
  * |
  * join
  * |   \
  * |    *
- * aggregate: Sum(x) as min1
+ * aggregate: Sum(x) as sum1
  * </pre>
  */
 public class PushDownSumThroughJoin implements RewriteRuleFactory {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java
index 3f4fa09cd71..88b13b383a3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java
@@ -19,9 +19,6 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
@@ -30,15 +27,9 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 
-import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableList.Builder;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 
 /**
@@ -79,7 +70,7 @@ public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory {
                                 return null;
                             }
                             LogicalAggregate<LogicalJoin<Plan, Plan>> agg = 
ctx.root;
-                            return pushSum(agg, agg.child(), 
ImmutableList.of());
+                            return 
PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child(), ImmutableList.of());
                         })
                         .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN),
                 logicalAggregate(logicalProject(innerLogicalJoin()))
@@ -98,112 +89,10 @@ public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory {
                                 return null;
                             }
                             LogicalAggregate<LogicalProject<LogicalJoin<Plan, 
Plan>>> agg = ctx.root;
-                            return pushSum(agg, agg.child().child(), 
agg.child().getProjects());
+                            return 
PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child().child(),
+                                    agg.child().getProjects());
                         })
                         .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN)
         );
     }
-
-    private LogicalAggregate<Plan> pushSum(LogicalAggregate<? extends Plan> 
agg,
-            LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
-        List<Slot> leftOutput = join.left().getOutput();
-        List<Slot> rightOutput = join.right().getOutput();
-
-        List<Sum> leftSums = new ArrayList<>();
-        List<Sum> rightSums = new ArrayList<>();
-        for (AggregateFunction f : agg.getAggregateFunctions()) {
-            Sum sum = (Sum) f;
-            Slot slot = (Slot) sum.child();
-            if (leftOutput.contains(slot)) {
-                leftSums.add(sum);
-            } else if (rightOutput.contains(slot)) {
-                rightSums.add(sum);
-            } else {
-                throw new IllegalStateException("Slot " + slot + " not found 
in join output");
-            }
-        }
-        if (leftSums.isEmpty() && rightSums.isEmpty()) {
-            return null;
-        }
-
-        Set<Slot> leftGroupBy = new HashSet<>();
-        Set<Slot> rightGroupBy = new HashSet<>();
-        for (Expression e : agg.getGroupByExpressions()) {
-            Slot slot = (Slot) e;
-            if (leftOutput.contains(slot)) {
-                leftGroupBy.add(slot);
-            } else if (rightOutput.contains(slot)) {
-                rightGroupBy.add(slot);
-            } else {
-                return null;
-            }
-        }
-        join.getHashJoinConjuncts().forEach(e -> 
e.getInputSlots().forEach(slot -> {
-            if (leftOutput.contains(slot)) {
-                leftGroupBy.add(slot);
-            } else if (rightOutput.contains(slot)) {
-                rightGroupBy.add(slot);
-            } else {
-                throw new IllegalStateException("Slot " + slot + " not found 
in join output");
-            }
-        }));
-
-        Plan left = join.left();
-        Plan right = join.right();
-
-        Map<Slot, NamedExpression> leftSumSlotToOutput = new HashMap<>();
-        Map<Slot, NamedExpression> rightSumSlotToOutput = new HashMap<>();
-
-        // left Sum agg
-        if (!leftSums.isEmpty()) {
-            Builder<NamedExpression> leftSumAggOutputBuilder = 
ImmutableList.<NamedExpression>builder()
-                    .addAll(leftGroupBy);
-            leftSums.forEach(func -> {
-                Alias alias = func.alias(func.getName());
-                leftSumSlotToOutput.put((Slot) func.child(0), alias);
-                leftSumAggOutputBuilder.add(alias);
-            });
-            left = new LogicalAggregate<>(ImmutableList.copyOf(leftGroupBy), 
leftSumAggOutputBuilder.build(),
-                    join.left());
-        }
-
-        // right Sum agg
-        if (!rightSums.isEmpty()) {
-            Builder<NamedExpression> rightSumAggOutputBuilder = 
ImmutableList.<NamedExpression>builder()
-                    .addAll(rightGroupBy);
-            rightSums.forEach(func -> {
-                Alias alias = func.alias(func.getName());
-                rightSumSlotToOutput.put((Slot) func.child(0), alias);
-                rightSumAggOutputBuilder.add(alias);
-            });
-            right = new LogicalAggregate<>(ImmutableList.copyOf(rightGroupBy), 
rightSumAggOutputBuilder.build(),
-                    join.right());
-        }
-
-        Preconditions.checkState(left != join.left() || right != join.right());
-        Plan newJoin = join.withChildren(left, right);
-
-        // top Sum agg
-        // replace sum(x) -> sum(sum#)
-        List<NamedExpression> newOutputExprs = new ArrayList<>();
-        for (NamedExpression ne : agg.getOutputExpressions()) {
-            if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) {
-                Sum oldTopSum = (Sum) ((Alias) ne).child();
-
-                Slot slot = (Slot) oldTopSum.child(0);
-                if (leftSumSlotToOutput.containsKey(slot)) {
-                    Expression expr = new 
Sum(leftSumSlotToOutput.get(slot).toSlot());
-                    newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
-                } else if (rightSumSlotToOutput.containsKey(slot)) {
-                    Expression expr = new 
Sum(rightSumSlotToOutput.get(slot).toSlot());
-                    newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
-                } else {
-                    throw new IllegalStateException("Slot " + slot + " not 
found in join output");
-                }
-            } else {
-                newOutputExprs.add(ne);
-            }
-        }
-        return agg.withAggOutputChild(newOutputExprs, newJoin);
-    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to