This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit ad1c19bd65644bc19392eef6ebc2e845b342fd37 Author: jakevin <[email protected]> AuthorDate: Mon Jan 22 12:23:50 2024 +0800 [refactor](Nereids): Eager Aggregation unify pushdown agg function (#30142) --- .../rules/rewrite/PushDownMinMaxThroughJoin.java | 17 ++- .../rules/rewrite/PushDownSumThroughJoin.java | 4 +- .../rewrite/PushDownSumThroughJoinOneSide.java | 117 +-------------------- 3 files changed, 19 insertions(+), 119 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java index 48ded00defe..3057f1eafc4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java @@ -81,7 +81,7 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { return null; } LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root; - return pushMinMax(agg, agg.child(), ImmutableList.of()); + return pushMinMaxSum(agg, agg.child(), ImmutableList.of()); }) .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN), logicalAggregate(logicalProject(innerLogicalJoin())) @@ -102,13 +102,16 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { return null; } LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root; - return pushMinMax(agg, agg.child().child(), agg.child().getProjects()); + return pushMinMaxSum(agg, agg.child().child(), agg.child().getProjects()); }) .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN) ); } - private LogicalAggregate<Plan> pushMinMax(LogicalAggregate<? extends Plan> agg, + /** + * Push down Min/Max/Sum through join. + */ + public static LogicalAggregate<Plan> pushMinMaxSum(LogicalAggregate<? extends Plan> agg, LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) { List<Slot> leftOutput = join.left().getOutput(); List<Slot> rightOutput = join.right().getOutput(); @@ -125,6 +128,9 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { throw new IllegalStateException("Slot " + slot + " not found in join output"); } } + if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) { + return null; + } Set<Slot> leftGroupBy = new HashSet<>(); Set<Slot> rightGroupBy = new HashSet<>(); @@ -177,6 +183,11 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { Preconditions.checkState(left != join.left() || right != join.right()); Plan newJoin = join.withChildren(left, right); + // top agg + // replace + // min(x) -> min(min#) + // max(x) -> max(max#) + // sum(x) -> sum(sum#) List<NamedExpression> newOutputExprs = new ArrayList<>(); for (NamedExpression ne : agg.getOutputExpressions()) { if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java index 91cb2a6050b..e8987e670a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java @@ -53,12 +53,12 @@ import java.util.Set; * | * * (x) * -> - * aggregate: Sum(min1) + * aggregate: Sum(sum1) * | * join * | \ * | * - * aggregate: Sum(x) as min1 + * aggregate: Sum(x) as sum1 * </pre> */ public class PushDownSumThroughJoin implements RewriteRuleFactory { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java index 3f4fa09cd71..88b13b383a3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java @@ -19,9 +19,6 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; @@ -30,15 +27,9 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableList.Builder; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; /** @@ -79,7 +70,7 @@ public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory { return null; } LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root; - return pushSum(agg, agg.child(), ImmutableList.of()); + return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child(), ImmutableList.of()); }) .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN), logicalAggregate(logicalProject(innerLogicalJoin())) @@ -98,112 +89,10 @@ public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory { return null; } LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root; - return pushSum(agg, agg.child().child(), agg.child().getProjects()); + return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child().child(), + agg.child().getProjects()); }) .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN) ); } - - private LogicalAggregate<Plan> pushSum(LogicalAggregate<? extends Plan> agg, - LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) { - List<Slot> leftOutput = join.left().getOutput(); - List<Slot> rightOutput = join.right().getOutput(); - - List<Sum> leftSums = new ArrayList<>(); - List<Sum> rightSums = new ArrayList<>(); - for (AggregateFunction f : agg.getAggregateFunctions()) { - Sum sum = (Sum) f; - Slot slot = (Slot) sum.child(); - if (leftOutput.contains(slot)) { - leftSums.add(sum); - } else if (rightOutput.contains(slot)) { - rightSums.add(sum); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } - if (leftSums.isEmpty() && rightSums.isEmpty()) { - return null; - } - - Set<Slot> leftGroupBy = new HashSet<>(); - Set<Slot> rightGroupBy = new HashSet<>(); - for (Expression e : agg.getGroupByExpressions()) { - Slot slot = (Slot) e; - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); - } else { - return null; - } - } - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - })); - - Plan left = join.left(); - Plan right = join.right(); - - Map<Slot, NamedExpression> leftSumSlotToOutput = new HashMap<>(); - Map<Slot, NamedExpression> rightSumSlotToOutput = new HashMap<>(); - - // left Sum agg - if (!leftSums.isEmpty()) { - Builder<NamedExpression> leftSumAggOutputBuilder = ImmutableList.<NamedExpression>builder() - .addAll(leftGroupBy); - leftSums.forEach(func -> { - Alias alias = func.alias(func.getName()); - leftSumSlotToOutput.put((Slot) func.child(0), alias); - leftSumAggOutputBuilder.add(alias); - }); - left = new LogicalAggregate<>(ImmutableList.copyOf(leftGroupBy), leftSumAggOutputBuilder.build(), - join.left()); - } - - // right Sum agg - if (!rightSums.isEmpty()) { - Builder<NamedExpression> rightSumAggOutputBuilder = ImmutableList.<NamedExpression>builder() - .addAll(rightGroupBy); - rightSums.forEach(func -> { - Alias alias = func.alias(func.getName()); - rightSumSlotToOutput.put((Slot) func.child(0), alias); - rightSumAggOutputBuilder.add(alias); - }); - right = new LogicalAggregate<>(ImmutableList.copyOf(rightGroupBy), rightSumAggOutputBuilder.build(), - join.right()); - } - - Preconditions.checkState(left != join.left() || right != join.right()); - Plan newJoin = join.withChildren(left, right); - - // top Sum agg - // replace sum(x) -> sum(sum#) - List<NamedExpression> newOutputExprs = new ArrayList<>(); - for (NamedExpression ne : agg.getOutputExpressions()) { - if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) { - Sum oldTopSum = (Sum) ((Alias) ne).child(); - - Slot slot = (Slot) oldTopSum.child(0); - if (leftSumSlotToOutput.containsKey(slot)) { - Expression expr = new Sum(leftSumSlotToOutput.get(slot).toSlot()); - newOutputExprs.add((NamedExpression) ne.withChildren(expr)); - } else if (rightSumSlotToOutput.containsKey(slot)) { - Expression expr = new Sum(rightSumSlotToOutput.get(slot).toSlot()); - newOutputExprs.add((NamedExpression) ne.withChildren(expr)); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } else { - newOutputExprs.add(ne); - } - } - return agg.withAggOutputChild(newOutputExprs, newJoin); - } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
