feiniaofeiafei commented on code in PR #59116:
URL: https://github.com/apache/doris/pull/59116#discussion_r2689590885


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,513 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * This rule rewrites grouping sets by pre-aggregating on the largest grouping set. e.g.:
+ *      select a, b, c, d, e, sum(f) from t group by rollup(a, b, c, d, e);
+ * is rewritten to:
+ *      with cte1 as (select a, b, c, d, e, sum(f) x from t group by a, b, c, d, e)
+ *      select * from cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) from cte1 group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(aggregateConsumer)
+ *     +--LogicalCTEConsumer(directConsumer)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
    /** Shared singleton instance of this rewriter. */
    public static final DecomposeRepeatWithPreAggregation INSTANCE = new DecomposeRepeatWithPreAggregation();
    /**
     * Aggregate functions whose results can be recomputed from pre-aggregated partial results:
     * sum/sum0/min/max/any_value are re-applied to the partial value and count is summed
     * (see constructAgg). NOTE(review): presumably consulted by canOptimize, which is not visible here.
     */
    private static final Set<Class<? extends AggregateFunction>> SUPPORT_AGG_FUNCTIONS =
            ImmutableSet.of(Sum.class, Sum0.class, Min.class, Max.class, AnyValue.class, Count.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer aggregateConsumer = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer directConsumer = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
aggregateConsumer.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        List<NamedExpression> groupingFunctionSlots = new ArrayList<>();
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, 
aggregateConsumer, newGroupingSets,
+                originToConsumerMap, groupingFunctionSlots);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        Map<AggregateFunction, Slot> aggFuncToSlot = new HashMap<>();
+        LogicalAggregate<Plan> topAgg = constructAgg(aggregate, 
originToConsumerMap, newRepeat, groupingFunctionSlots,
+                aggFuncToSlot);
+        LogicalProject<Plan> project = constructProject(aggregate, 
originToConsumerMap, needRemovedExprSet,
+                groupingFunctionSlots, topAgg, aggFuncToSlot);
+        LogicalPlan directChild = getDirectChild(directConsumer, 
groupingFunctionSlots);
+        return constructUnion(project, directChild, aggregate);
+    }
+
+    /**
+     * Get the direct child plan for the union operation.
+     * If there are grouping function slots, wrap the consumer with a project 
that adds
+     * zero literals for each grouping function slot to match the output 
schema.
+     *
+     * @param directConsumer the CTE consumer for the direct path
+     * @param groupingFunctionSlots the list of grouping function slots to 
handle
+     * @return the direct child plan, possibly wrapped with a project
+     */
+    private LogicalPlan getDirectChild(LogicalCTEConsumer directConsumer, 
List<NamedExpression> groupingFunctionSlots) {
+        LogicalPlan directChild = directConsumer;
+        if (!groupingFunctionSlots.isEmpty()) {
+            ImmutableList.Builder<NamedExpression> builder = 
ImmutableList.builder();
+            builder.addAll(directConsumer.getOutput());
+            for (int i = 0; i < groupingFunctionSlots.size(); ++i) {
+                builder.add(new Alias(new BigIntLiteral(0)));
+            }
+            directChild = new LogicalProject<Plan>(builder.build(), 
directConsumer);
+        }
+        return directChild;
+    }
+
+    /**
+     * Build a map from aggregate function to its corresponding slot.
+     *
+     * @param outputExpressions the output expressions containing aggregate 
functions
+     * @param pToc the map from producer slot to consumer slot
+     * @return a map from aggregate function to its corresponding slot in 
consumer outputs
+     */
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Optional<Expression> aggFunc = expr.child(0).collectFirst(e -> 
e instanceof AggregateFunction);
+                aggFunc.ifPresent(
+                        func -> aggFuncSlotMap.put((AggregateFunction) func, 
pToc.get(expr.toSlot())));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    /**
+     * Get the set of expressions that need to be replaced with null in the 
new grouping sets.
+     * These are expressions that exist in the maximum grouping set but not in 
other grouping sets.
+     *
+     * @param repeat the original LogicalRepeat plan
+     * @param newGroupingSets the new grouping sets after removing the maximum 
grouping set
+     * @param maxGroupIndex the index of the maximum grouping set
+     * @return the set of expressions that need to be replaced with null
+     */
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,
+            List<List<Expression>> newGroupingSets, int maxGroupIndex) {
+        Set<Expression> otherGroupExprSet = new HashSet<>();
+        for (List<Expression> groupingSet : newGroupingSets) {
+            otherGroupExprSet.addAll(groupingSet);
+        }
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList);
+        needRemovedExprSet.removeAll(otherGroupExprSet);
+        return needRemovedExprSet;
+    }
+
+    /**
+     * Construct a LogicalAggregate for the decomposed repeat.
+     *
+     * @param aggregate the original aggregate plan
+     * @param originToConsumerMap the map from original slots to consumer slots
+     * @param newRepeat the new LogicalRepeat plan with reduced grouping sets
+     * @param groupingFunctionSlots the list of new grouping function slots
+     * @param aggFuncToSlot output parameter: map from original aggregate 
functions to their slots in the new aggregate
+     * @return a LogicalAggregate for the decomposed repeat
+     */
+    private LogicalAggregate<Plan> constructAgg(LogicalAggregate<? extends 
Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> newRepeat,
+            List<NamedExpression> groupingFunctionSlots, 
Map<AggregateFunction, Slot> aggFuncToSlot) {
+        Map<AggregateFunction, Slot> aggFuncSlotMap = 
getAggFuncSlotMap(aggregate.getOutputExpressions(),
+                originToConsumerMap);
+        Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf(
+                ExpressionUtils.flatExpressions((List) 
newRepeat.getGroupingSets()));
+        List<Expression> topAggGby = new ArrayList<>(groupingSetsUsedSlot);
+        topAggGby.add(newRepeat.getGroupingId().get());
+        topAggGby.addAll(groupingFunctionSlots);
+        List<NamedExpression> topAggOutput = new ArrayList<>((List) topAggGby);
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                NamedExpression aggFuncAfterRewrite = (NamedExpression) 
expr.rewriteDownShortCircuit(e -> {
+                    if (e instanceof AggregateFunction) {
+                        if (e instanceof Count) {
+                            return new Sum(aggFuncSlotMap.get(e));
+                        } else {
+                            return e.withChildren(aggFuncSlotMap.get(e));
+                        }
+                    } else {
+                        return e;
+                    }
+                });
+                aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite)
+                        .withExprId(StatementScopeIdGenerator.newExprId());
+                NamedExpression replacedExpr = (NamedExpression) 
aggFuncAfterRewrite.rewriteDownShortCircuit(
+                        e -> {
+                            if (originToConsumerMap.containsKey(e)) {
+                                return originToConsumerMap.get(e);
+                            } else {
+                                return e;
+                            }
+                        }
+                );
+                topAggOutput.add(replacedExpr);
+                aggFuncToSlot.put((AggregateFunction) expr.collectFirst(e -> e 
instanceof AggregateFunction).get(),
+                        replacedExpr.toSlot());
+            }
+        }
+        return new LogicalAggregate<>(topAggGby, topAggOutput, 
Optional.of(newRepeat), newRepeat);
+    }
+
+    /**
+     * Construct a LogicalProject that wraps the aggregate and handles output 
expressions.
+     * This method replaces removed expressions with null literals, and output 
the grouping scalar functions
+     * at the end of the projections.
+     *
+     * @param aggregate the original aggregate plan
+     * @param originToConsumerMap the map from original slots to consumer slots
+     * @param needRemovedExprSet the set of expressions that need to be 
replaced with null
+     * @param groupingFunctionSlots the list of grouping function slots to add 
to the project
+     * @param topAgg the aggregate plan to wrap
+     * @param aggFuncToSlot the map from aggregate functions to their slots
+     * @return a LogicalProject wrapping the aggregate with proper output 
expressions
+     */
+    private LogicalProject<Plan> constructProject(LogicalAggregate<? extends 
Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, Set<Expression> 
needRemovedExprSet,
+            List<NamedExpression> groupingFunctionSlots, 
LogicalAggregate<Plan> topAgg,
+            Map<AggregateFunction, Slot> aggFuncToSlot) {
+        LogicalRepeat<?> repeat = (LogicalRepeat<?>) aggregate.child(0);
+        Set<ExprId> originGroupingFunctionId = new HashSet<>();
+        for (NamedExpression namedExpression : 
repeat.getGroupingScalarFunctionAlias()) {
+            originGroupingFunctionId.add(namedExpression.getExprId());
+        }
+        ImmutableList.Builder<NamedExpression> projects = 
ImmutableList.builder();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (needRemovedExprSet.contains(expr)) {
+                projects.add(new Alias(new NullLiteral(expr.getDataType()), 
expr.getName()));
+            } else if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                AggregateFunction aggregateFunction = (AggregateFunction) 
expr.collectFirst(
+                        e -> e instanceof AggregateFunction).get();
+                projects.add(aggFuncToSlot.get(aggregateFunction));
+            } else if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())
+                    || originGroupingFunctionId.contains(expr.getExprId())) {
+                continue;
+            } else {
+                NamedExpression replacedExpr = (NamedExpression) 
expr.rewriteDownShortCircuit(
+                        e -> {
+                            if (originToConsumerMap.containsKey(e)) {
+                                return originToConsumerMap.get(e);
+                            } else {
+                                return e;
+                            }
+                        }
+                );
+                projects.add(replacedExpr.toSlot());
+            }
+        }
+        projects.addAll(groupingFunctionSlots);
+        return new LogicalProject<>(projects.build(), topAgg);
+    }
+
+    /**
+     * Construct a LogicalUnion that combines the results from the decomposed 
repeat
+     * and the CTE consumer.
+     *
+     * @param aggregateProject the first child plan (project with aggregate)
+     * @param directConsumer the second child plan (CTE consumer)
+     * @param aggregate the original aggregate plan for output reference
+     * @return a LogicalUnion combining the two children
+     */
+    private LogicalUnion constructUnion(LogicalPlan aggregateProject, 
LogicalPlan directConsumer,
+            LogicalAggregate<? extends Plan> aggregate) {
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<NamedExpression> unionOutputs = new ArrayList<>();
+        List<List<SlotReference>> childrenOutputs = new ArrayList<>();
+        childrenOutputs.add((List) aggregateProject.getOutput());
+        childrenOutputs.add((List) directConsumer.getOutput());
+        Set<ExprId> groupingFunctionId = new HashSet<>();
+        for (NamedExpression alias : repeat.getGroupingScalarFunctionAlias()) {
+            groupingFunctionId.add(alias.getExprId());
+        }
+        List<NamedExpression> groupingFunctionSlots = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                continue;
+            }
+            if (groupingFunctionId.contains(expr.getExprId())) {
+                groupingFunctionSlots.add(expr.toSlot());
+                continue;
+            }
+            unionOutputs.add(expr.toSlot());
+        }
+        unionOutputs.addAll(groupingFunctionSlots);
+        return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, 
ImmutableList.of(),
+                false, ImmutableList.of(aggregateProject, directConsumer));
+    }
+
+    /**
+     * Determine if optimization is possible; if so, return the index of the 
largest group.
+     * The optimization requires:
+     * 1. The aggregate's child must be a LogicalRepeat
+     * 2. All aggregate functions must be Sum, Min, or Max (non-distinct)
+     * 3. No GroupingScalarFunction in repeat output
+     * 4. More than 3 grouping sets

Review Comment:
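   目前就先拍了一个3,大于3才能优化 (For now the threshold of 3 is an arbitrary choice; the optimization only applies when there are more than 3 grouping sets.)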



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to