englefly commented on code in PR #59116: URL: https://github.com/apache/doris/pull/59116#discussion_r2689193825
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java: ########## @@ -0,0 +1,513 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext; +import org.apache.doris.nereids.trees.copier.DeepCopierContext; +import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import 
org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * This rule will rewrite grouping sets. 
eg: + * select a, b, c, d, e, sum(f) from t group by rollup(a, b, c, d, e); + * rewrite to: + * with cte1 as (select a, b, c, d, e, sum(f) x from t group by a, b, c, d, e) + * select * from cte1 + * union all + * select a, b, c, d, null, sum(x) x from cte1 group by rollup(a, b, c, d) + * + * LogicalAggregate(gby: a,b,c,d,e,grouping_id output:a,b,c,d,e,grouping_id,sum(f)) + * +--LogicalRepeat(grouping sets: (a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),()) + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer(cte) + * +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x) + * +--LogicalUnionAll + * +--LogicalProject(a,b,c,d, null as e, sum(x)) + * +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x)) + * +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),()) + * +--LogicalCTEConsumer(aggregateConsumer) + * +--LogicalCTEConsumer(directConsumer) + */ +public class DecomposeRepeatWithPreAggregation extends DefaultPlanRewriter<DistinctSelectorContext> + implements CustomRewriter { + public static final DecomposeRepeatWithPreAggregation INSTANCE = new DecomposeRepeatWithPreAggregation(); + private static final Set<Class<? extends AggregateFunction>> SUPPORT_AGG_FUNCTIONS = + ImmutableSet.of(Sum.class, Sum0.class, Min.class, Max.class, AnyValue.class, Count.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + DistinctSelectorContext ctx = new DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(), + jobContext.getCascadesContext()); + plan = plan.accept(this, ctx); + for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) { + LogicalCTEProducer<? extends Plan> producer = ctx.cteProducerList.get(i); + plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan); + } + return plan; + } + + @Override + public Plan visitLogicalCTEAnchor( + LogicalCTEAnchor<? extends Plan, ?
extends Plan> anchor, DistinctSelectorContext ctx) { + Plan child1 = anchor.child(0).accept(this, ctx); + DistinctSelectorContext consumerContext = + new DistinctSelectorContext(ctx.statementContext, ctx.cascadesContext); + Plan child2 = anchor.child(1).accept(this, consumerContext); + for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) { + LogicalCTEProducer<? extends Plan> producer = consumerContext.cteProducerList.get(i); + child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, child2); + } + return anchor.withChildren(ImmutableList.of(child1, child2)); + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, DistinctSelectorContext ctx) { + aggregate = visitChildren(this, aggregate, ctx); + int maxGroupIndex = canOptimize(aggregate); + if (maxGroupIndex < 0) { + return aggregate; + } + Map<Slot, Slot> preToProducerSlotMap = new HashMap<>(); + LogicalCTEProducer<LogicalAggregate<Plan>> producer = constructProducer(aggregate, maxGroupIndex, ctx, + preToProducerSlotMap); + LogicalCTEConsumer aggregateConsumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), + producer.getCteId(), "", producer); + LogicalCTEConsumer directConsumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), + producer.getCteId(), "", producer); + + // build map : origin slot to consumer slot + Map<Slot, Slot> producerToConsumerMap = new HashMap<>(); + for (Map.Entry<Slot, Slot> entry : aggregateConsumer.getProducerToConsumerOutputMap().entries()) { + producerToConsumerMap.put(entry.getKey(), entry.getValue()); + } + Map<Slot, Slot> originToConsumerMap = new HashMap<>(); + for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) { + originToConsumerMap.put(entry.getKey(), producerToConsumerMap.get(entry.getValue())); + } + + LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child(); + List<List<Expression>> newGroupingSets = new ArrayList<>(); + for (int i = 0; i < 
repeat.getGroupingSets().size(); ++i) { + if (i == maxGroupIndex) { + continue; + } + newGroupingSets.add(repeat.getGroupingSets().get(i)); + } + List<NamedExpression> groupingFunctionSlots = new ArrayList<>(); + LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, aggregateConsumer, newGroupingSets, + originToConsumerMap, groupingFunctionSlots); + Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, newGroupingSets, maxGroupIndex); + Map<AggregateFunction, Slot> aggFuncToSlot = new HashMap<>(); + LogicalAggregate<Plan> topAgg = constructAgg(aggregate, originToConsumerMap, newRepeat, groupingFunctionSlots, + aggFuncToSlot); + LogicalProject<Plan> project = constructProject(aggregate, originToConsumerMap, needRemovedExprSet, + groupingFunctionSlots, topAgg, aggFuncToSlot); + LogicalPlan directChild = getDirectChild(directConsumer, groupingFunctionSlots); + return constructUnion(project, directChild, aggregate); + } + + /** + * Get the direct child plan for the union operation. + * If there are grouping function slots, wrap the consumer with a project that adds + * zero literals for each grouping function slot to match the output schema. 
+ * + * @param directConsumer the CTE consumer for the direct path + * @param groupingFunctionSlots the list of grouping function slots to handle + * @return the direct child plan, possibly wrapped with a project + */ + private LogicalPlan getDirectChild(LogicalCTEConsumer directConsumer, List<NamedExpression> groupingFunctionSlots) { + LogicalPlan directChild = directConsumer; + if (!groupingFunctionSlots.isEmpty()) { + ImmutableList.Builder<NamedExpression> builder = ImmutableList.builder(); + builder.addAll(directConsumer.getOutput()); + for (int i = 0; i < groupingFunctionSlots.size(); ++i) { + builder.add(new Alias(new BigIntLiteral(0))); + } + directChild = new LogicalProject<Plan>(builder.build(), directConsumer); + } + return directChild; + } + + /** + * Build a map from aggregate function to its corresponding slot. + * + * @param outputExpressions the output expressions containing aggregate functions + * @param pToc the map from producer slot to consumer slot + * @return a map from aggregate function to its corresponding slot in consumer outputs + */ + private Map<AggregateFunction, Slot> getAggFuncSlotMap(List<NamedExpression> outputExpressions, + Map<Slot, Slot> pToc) { + // build map : aggFunc to Slot + Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>(); + for (NamedExpression expr : outputExpressions) { + if (expr instanceof Alias) { + Optional<Expression> aggFunc = expr.child(0).collectFirst(e -> e instanceof AggregateFunction); + aggFunc.ifPresent( + func -> aggFuncSlotMap.put((AggregateFunction) func, pToc.get(expr.toSlot()))); + } + } + return aggFuncSlotMap; + } + + /** + * Get the set of expressions that need to be replaced with null in the new grouping sets. + * These are expressions that exist in the maximum grouping set but not in other grouping sets. 
+ * + * @param repeat the original LogicalRepeat plan + * @param newGroupingSets the new grouping sets after removing the maximum grouping set + * @param maxGroupIndex the index of the maximum grouping set + * @return the set of expressions that need to be replaced with null + */ + private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> repeat, + List<List<Expression>> newGroupingSets, int maxGroupIndex) { + Set<Expression> otherGroupExprSet = new HashSet<>(); + for (List<Expression> groupingSet : newGroupingSets) { + otherGroupExprSet.addAll(groupingSet); + } + List<Expression> maxGroupByList = repeat.getGroupingSets().get(maxGroupIndex); + Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList); + needRemovedExprSet.removeAll(otherGroupExprSet); + return needRemovedExprSet; + } + + /** + * Construct a LogicalAggregate for the decomposed repeat. + * + * @param aggregate the original aggregate plan + * @param originToConsumerMap the map from original slots to consumer slots + * @param newRepeat the new LogicalRepeat plan with reduced grouping sets + * @param groupingFunctionSlots the list of new grouping function slots + * @param aggFuncToSlot output parameter: map from original aggregate functions to their slots in the new aggregate + * @return a LogicalAggregate for the decomposed repeat + */ + private LogicalAggregate<Plan> constructAgg(LogicalAggregate<? 
extends Plan> aggregate, + Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> newRepeat, + List<NamedExpression> groupingFunctionSlots, Map<AggregateFunction, Slot> aggFuncToSlot) { + Map<AggregateFunction, Slot> aggFuncSlotMap = getAggFuncSlotMap(aggregate.getOutputExpressions(), + originToConsumerMap); + Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf( + ExpressionUtils.flatExpressions((List) newRepeat.getGroupingSets())); + List<Expression> topAggGby = new ArrayList<>(groupingSetsUsedSlot); + topAggGby.add(newRepeat.getGroupingId().get()); + topAggGby.addAll(groupingFunctionSlots); + List<NamedExpression> topAggOutput = new ArrayList<>((List) topAggGby); + for (NamedExpression expr : aggregate.getOutputExpressions()) { + if (expr instanceof Alias && expr.containsType(AggregateFunction.class)) { + NamedExpression aggFuncAfterRewrite = (NamedExpression) expr.rewriteDownShortCircuit(e -> { + if (e instanceof AggregateFunction) { + if (e instanceof Count) { + return new Sum(aggFuncSlotMap.get(e)); + } else { + return e.withChildren(aggFuncSlotMap.get(e)); + } + } else { + return e; + } + }); + aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite) + .withExprId(StatementScopeIdGenerator.newExprId()); + NamedExpression replacedExpr = (NamedExpression) aggFuncAfterRewrite.rewriteDownShortCircuit( + e -> { + if (originToConsumerMap.containsKey(e)) { + return originToConsumerMap.get(e); + } else { + return e; + } + } + ); + topAggOutput.add(replacedExpr); + aggFuncToSlot.put((AggregateFunction) expr.collectFirst(e -> e instanceof AggregateFunction).get(), + replacedExpr.toSlot()); + } + } + return new LogicalAggregate<>(topAggGby, topAggOutput, Optional.of(newRepeat), newRepeat); + } + + /** + * Construct a LogicalProject that wraps the aggregate and handles output expressions. + * This method replaces removed expressions with null literals, and output the grouping scalar functions + * at the end of the projections. 
+ * + * @param aggregate the original aggregate plan + * @param originToConsumerMap the map from original slots to consumer slots + * @param needRemovedExprSet the set of expressions that need to be replaced with null + * @param groupingFunctionSlots the list of grouping function slots to add to the project + * @param topAgg the aggregate plan to wrap + * @param aggFuncToSlot the map from aggregate functions to their slots + * @return a LogicalProject wrapping the aggregate with proper output expressions + */ + private LogicalProject<Plan> constructProject(LogicalAggregate<? extends Plan> aggregate, + Map<Slot, Slot> originToConsumerMap, Set<Expression> needRemovedExprSet, + List<NamedExpression> groupingFunctionSlots, LogicalAggregate<Plan> topAgg, + Map<AggregateFunction, Slot> aggFuncToSlot) { + LogicalRepeat<?> repeat = (LogicalRepeat<?>) aggregate.child(0); + Set<ExprId> originGroupingFunctionId = new HashSet<>(); + for (NamedExpression namedExpression : repeat.getGroupingScalarFunctionAlias()) { + originGroupingFunctionId.add(namedExpression.getExprId()); + } + ImmutableList.Builder<NamedExpression> projects = ImmutableList.builder(); + for (NamedExpression expr : aggregate.getOutputExpressions()) { + if (needRemovedExprSet.contains(expr)) { + projects.add(new Alias(new NullLiteral(expr.getDataType()), expr.getName())); + } else if (expr instanceof Alias && expr.containsType(AggregateFunction.class)) { + AggregateFunction aggregateFunction = (AggregateFunction) expr.collectFirst( + e -> e instanceof AggregateFunction).get(); + projects.add(aggFuncToSlot.get(aggregateFunction)); + } else if (expr.getExprId().equals(repeat.getGroupingId().get().getExprId()) + || originGroupingFunctionId.contains(expr.getExprId())) { + continue; + } else { + NamedExpression replacedExpr = (NamedExpression) expr.rewriteDownShortCircuit( + e -> { + if (originToConsumerMap.containsKey(e)) { + return originToConsumerMap.get(e); + } else { + return e; + } + } + ); + 
projects.add(replacedExpr.toSlot()); + } + } + projects.addAll(groupingFunctionSlots); + return new LogicalProject<>(projects.build(), topAgg); + } + + /** + * Construct a LogicalUnion that combines the results from the decomposed repeat + * and the CTE consumer. + * + * @param aggregateProject the first child plan (project with aggregate) + * @param directConsumer the second child plan (CTE consumer) + * @param aggregate the original aggregate plan for output reference + * @return a LogicalUnion combining the two children + */ + private LogicalUnion constructUnion(LogicalPlan aggregateProject, LogicalPlan directConsumer, + LogicalAggregate<? extends Plan> aggregate) { + LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child(); + List<NamedExpression> unionOutputs = new ArrayList<>(); + List<List<SlotReference>> childrenOutputs = new ArrayList<>(); + childrenOutputs.add((List) aggregateProject.getOutput()); + childrenOutputs.add((List) directConsumer.getOutput()); + Set<ExprId> groupingFunctionId = new HashSet<>(); + for (NamedExpression alias : repeat.getGroupingScalarFunctionAlias()) { + groupingFunctionId.add(alias.getExprId()); + } + List<NamedExpression> groupingFunctionSlots = new ArrayList<>(); + for (NamedExpression expr : aggregate.getOutputExpressions()) { + if (expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) { + continue; + } + if (groupingFunctionId.contains(expr.getExprId())) { + groupingFunctionSlots.add(expr.toSlot()); + continue; + } + unionOutputs.add(expr.toSlot()); + } + unionOutputs.addAll(groupingFunctionSlots); + return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, ImmutableList.of(), + false, ImmutableList.of(aggregateProject, directConsumer)); + } + + /** + * Determine if optimization is possible; if so, return the index of the largest group. + * The optimization requires: + * 1. The aggregate's child must be a LogicalRepeat + * 2. 
All aggregate functions must be Sum, Sum0, Min, Max, AnyValue, or Count (non-distinct) + * 3. No GroupingScalarFunction in repeat output + * 4. More than 3 grouping sets Review Comment: 3 grouping sets 的情况也可以有优化吧 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
