This is an automated email from the ASF dual-hosted git repository.

zclll pushed a commit to branch tpc_preview6
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/tpc_preview6 by this push:
     new 448c8ac2425 pick decomposeRepeat fix bug and choose one shuffle key in decomposeRepeat (#60385)
448c8ac2425 is described below

commit 448c8ac2425b1e81a2bd7e3c158c4cac28c239c6
Author: feiniaofeiafei <[email protected]>
AuthorDate: Fri Jan 30 12:18:54 2026 +0800

    pick decomposeRepeat fix bug and choose one shuffle key in decomposeRepeat (#60385)
---
 .../glue/translator/PhysicalPlanTranslator.java    |   2 +-
 .../doris/nereids/parser/LogicalPlanBuilder.java   |   9 +-
 .../mv/AbstractMaterializedViewAggregateRule.java  |   3 +-
 .../implementation/SplitAggWithoutDistinct.java    |   5 +-
 .../rewrite/DecomposeRepeatWithPreAggregation.java | 148 +++++++++++++--
 .../trees/copier/LogicalPlanDeepCopier.java        |  17 +-
 .../doris/nereids/trees/plans/algebra/Repeat.java  |  60 ++++--
 .../trees/plans/logical/LogicalAggregate.java      |  80 +++++---
 .../nereids/trees/plans/logical/LogicalRepeat.java |  35 ++--
 .../java/org/apache/doris/qe/SessionVariable.java  |   9 +
 .../rules/analysis/NormalizeRepeatTest.java        |   4 +
 .../DecomposeRepeatWithPreAggregationTest.java     |   3 +
 .../PushDownFilterThroughAggregationTest.java      |   5 +-
 .../trees/copier/LogicalPlanDeepCopierTest.java    |   2 +
 .../nereids/trees/plans/algebra/RepeatTest.java    | 206 +++++++++++++++++++++
 .../nereids_p0/repeat/test_repeat_output_slot.out  |  11 +-
 .../decompose_repeat/decompose_repeat.out          | 184 ++++++++++++++++++
 .../decompose_repeat/decompose_repeat.groovy       |  32 ++++
 18 files changed, 729 insertions(+), 86 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 81903aa0ec6..9956cbe318d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -2567,7 +2567,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         // cube and rollup already convert to grouping sets in 
LogicalPlanBuilder.withAggregate()
         GroupingInfo groupingInfo = new GroupingInfo(outputTuple, 
preRepeatExprs);
 
-        List<Set<Integer>> repeatSlotIdList = 
repeat.computeRepeatSlotIdList(getSlotIds(outputTuple));
+        List<Set<Integer>> repeatSlotIdList = 
repeat.computeRepeatSlotIdList(getSlotIds(outputTuple), outputSlots);
         Set<Integer> allSlotId = repeatSlotIdList.stream()
                 .flatMap(Set::stream)
                 .collect(ImmutableSet.toImmutableSet());
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
index 5c48122262c..3aedce70308 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
@@ -609,6 +609,7 @@ import org.apache.doris.nereids.trees.plans.PlanType;
 import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
 import org.apache.doris.nereids.trees.plans.algebra.InlineTable;
 import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation;
+import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType;
 import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
 import org.apache.doris.nereids.trees.plans.commands.AddConstraintCommand;
 import 
org.apache.doris.nereids.trees.plans.commands.AdminCancelRebalanceDiskCommand;
@@ -4765,15 +4766,15 @@ public class LogicalPlanBuilder extends 
DorisParserBaseVisitor<Object> {
                 for (GroupingSetContext groupingSetContext : 
groupingElementContext.groupingSet()) {
                     groupingSets.add(visit(groupingSetContext.expression(), 
Expression.class));
                 }
-                return new LogicalRepeat<>(groupingSets.build(), 
namedExpressions, input);
+                return new LogicalRepeat<>(groupingSets.build(), 
namedExpressions, RepeatType.GROUPING_SETS, input);
             } else if (groupingElementContext.CUBE() != null) {
                 List<Expression> cubeExpressions = 
visit(groupingElementContext.expression(), Expression.class);
                 List<List<Expression>> groupingSets = 
ExpressionUtils.cubeToGroupingSets(cubeExpressions);
-                return new LogicalRepeat<>(groupingSets, namedExpressions, 
input);
+                return new LogicalRepeat<>(groupingSets, namedExpressions, 
RepeatType.CUBE, input);
             } else if (groupingElementContext.ROLLUP() != null && 
groupingElementContext.WITH() == null) {
                 List<Expression> rollupExpressions = 
visit(groupingElementContext.expression(), Expression.class);
                 List<List<Expression>> groupingSets = 
ExpressionUtils.rollupToGroupingSets(rollupExpressions);
-                return new LogicalRepeat<>(groupingSets, namedExpressions, 
input);
+                return new LogicalRepeat<>(groupingSets, namedExpressions, 
RepeatType.ROLLUP, input);
             } else {
                 List<GroupKeyWithOrder> groupKeyWithOrders = 
visit(groupingElementContext.expressionWithOrder(),
                         GroupKeyWithOrder.class);
@@ -4787,7 +4788,7 @@ public class LogicalPlanBuilder extends 
DorisParserBaseVisitor<Object> {
                 }
                 if (groupingElementContext.ROLLUP() != null) {
                     List<List<Expression>> groupingSets = 
ExpressionUtils.rollupToGroupingSets(groupByExpressions);
-                    return new LogicalRepeat<>(groupingSets, namedExpressions, 
input);
+                    return new LogicalRepeat<>(groupingSets, namedExpressions, 
RepeatType.ROLLUP, input);
                 } else {
                     return new LogicalAggregate<>(groupByExpressions, 
namedExpressions, input);
                 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
index 87ddc7a0ca5..211f2578721 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
@@ -264,7 +264,8 @@ public abstract class AbstractMaterializedViewAggregateRule 
extends AbstractMate
                 }
             }
             LogicalRepeat<Plan> repeat = new 
LogicalRepeat<>(rewrittenGroupSetsExpressions,
-                    finalOutputExpressions, 
queryStructInfo.getGroupingId().get(), tempRewritedPlan);
+                    finalOutputExpressions, 
queryStructInfo.getGroupingId().get(),
+                    queryAggregate.getSourceRepeat().get().getRepeatType(), 
tempRewritedPlan);
             return NormalizeRepeat.doNormalize(repeat);
         }
         return new LogicalAggregate<>(finalGroupExpressions, 
finalOutputExpressions, tempRewritedPlan);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/SplitAggWithoutDistinct.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/SplitAggWithoutDistinct.java
index ed94fa730ca..de9526005d0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/SplitAggWithoutDistinct.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/SplitAggWithoutDistinct.java
@@ -96,7 +96,8 @@ public class SplitAggWithoutDistinct extends 
OneImplementationRuleFactory {
                 }
         );
         AggregateParam param = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_RESULT, !skipRegulator(logicalAgg));
-        return ImmutableList.of(new 
PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), aggOutput, param,
+        return ImmutableList.of(new 
PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), aggOutput,
+                logicalAgg.getPartitionExpressions(), param,
                 
AggregateUtils.maybeUsingStreamAgg(logicalAgg.getGroupByExpressions(), param),
                 null, logicalAgg.child()));
     }
@@ -159,7 +160,7 @@ public class SplitAggWithoutDistinct extends 
OneImplementationRuleFactory {
                     return new AggregateExpression(aggFunc, 
bufferToResultParam, alias.toSlot());
                 });
         return ImmutableList.of(new 
PhysicalHashAggregate<>(aggregate.getGroupByExpressions(),
-                globalAggOutput, bufferToResultParam,
+                globalAggOutput, aggregate.getPartitionExpressions(), 
bufferToResultParam,
                 
AggregateUtils.maybeUsingStreamAgg(aggregate.getGroupByExpressions(), 
bufferToResultParam),
                 aggregate.getLogicalProperties(), localAgg));
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java
index 6f6a7f373a5..3939d79a510 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.jobs.JobContext;
 import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.rules.rewrite.StatsDerive.DeriveContext;
 import org.apache.doris.nereids.trees.copier.DeepCopierContext;
 import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
 import org.apache.doris.nereids.trees.expressions.Alias;
@@ -50,17 +51,22 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
 import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
 import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
 import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.Statistics;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 
 import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.TreeMap;
 
 /**
  * This rule will rewrite grouping sets. eg:
@@ -119,13 +125,13 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
     @Override
     public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
         aggregate = visitChildren(this, aggregate, ctx);
-        int maxGroupIndex = canOptimize(aggregate);
+        int maxGroupIndex = canOptimize(aggregate, 
ctx.cascadesContext.getConnectContext());
         if (maxGroupIndex < 0) {
             return aggregate;
         }
         Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
         LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
-                preToProducerSlotMap);
+                preToProducerSlotMap, ctx.cascadesContext.getConnectContext());
         LogicalCTEConsumer aggregateConsumer = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
                 producer.getCteId(), "", producer);
         LogicalCTEConsumer directConsumer = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
@@ -276,6 +282,8 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
                         replacedExpr.toSlot());
             }
         }
+        // NOTE: shuffle key selection is applied on the pre-agg (producer) 
side by setting
+        // LogicalAggregate.partitionExpressions. See constructProducer().
         return new LogicalAggregate<>(topAggGby, topAggOutput, 
Optional.of(newRepeat), newRepeat);
     }
 
@@ -369,16 +377,14 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
      * Determine if optimization is possible; if so, return the index of the 
largest group.
      * The optimization requires:
      * 1. The aggregate's child must be a LogicalRepeat
-     * 2. All aggregate functions must be Sum, Min, or Max (non-distinct)
-     * 3. No GroupingScalarFunction in repeat output
-     * 4. More than 3 grouping sets
-     * 5. There exists a grouping set that contains all other grouping sets
-     *
+     * 2. All aggregate functions must be in SUPPORT_AGG_FUNCTIONS.
+     * 3. More than 3 grouping sets
+     * 4. There exists a grouping set that contains all other grouping sets
      * @param aggregate the aggregate plan to check
      * @return value -1 means can not be optimized, values other than -1
      *      represent the index of the set that contains all other sets
      */
-    private int canOptimize(LogicalAggregate<? extends Plan> aggregate) {
+    private int canOptimize(LogicalAggregate<? extends Plan> aggregate, 
ConnectContext connectContext) {
         Plan aggChild = aggregate.child();
         if (!(aggChild instanceof LogicalRepeat)) {
             return -1;
@@ -398,10 +404,14 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
         // This is an empirical threshold: when there are too few grouping 
sets,
         // the overhead of creating CTE and union may outweigh the benefits.
         // The value 3 is chosen heuristically based on practical experience.
-        if (groupingSets.size() <= 3) {
+        if (groupingSets.size() <= 
connectContext.getSessionVariable().decomposeRepeatThreshold) {
+            return -1;
+        }
+        int maxGroupIndex = findMaxGroupingSetIndex(groupingSets);
+        if (maxGroupIndex < 0) {
             return -1;
         }
-        return findMaxGroupingSetIndex(groupingSets);
+        return maxGroupIndex;
     }
 
     /**
@@ -450,7 +460,8 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
      * @return a LogicalCTEProducer containing the pre-aggregation
      */
     private LogicalCTEProducer<LogicalAggregate<Plan>> 
constructProducer(LogicalAggregate<? extends Plan> aggregate,
-            int maxGroupIndex, DistinctSelectorContext ctx, Map<Slot, Slot> 
preToCloneSlotMap) {
+            int maxGroupIndex, DistinctSelectorContext ctx, Map<Slot, Slot> 
preToCloneSlotMap,
+            ConnectContext connectContext) {
         LogicalRepeat<? extends Plan> repeat = (LogicalRepeat<? extends Plan>) 
aggregate.child();
         List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
         List<NamedExpression> originAggOutputs = 
aggregate.getOutputExpressions();
@@ -469,6 +480,11 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
         }
 
         LogicalAggregate<Plan> preAgg = new LogicalAggregate<>(maxGroupByList, 
orderedAggOutputs, repeat.child());
+        Optional<List<Expression>> partitionExprs = 
choosePreAggShuffleKeyPartitionExprs(
+                repeat, maxGroupIndex, maxGroupByList, connectContext);
+        if (partitionExprs.isPresent() && !partitionExprs.get().isEmpty()) {
+            preAgg = preAgg.withPartitionExpressions(partitionExprs);
+        }
         LogicalAggregate<Plan> preAggClone = (LogicalAggregate<Plan>) 
LogicalPlanDeepCopier.INSTANCE
                 .deepCopy(preAgg, new DeepCopierContext());
         for (int i = 0; i < preAgg.getOutputExpressions().size(); ++i) {
@@ -480,6 +496,116 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
         return producer;
     }
 
+    /**
+     * Choose partition expressions (shuffle key) for pre-aggregation 
(producer agg).
+     */
+    private Optional<List<Expression>> choosePreAggShuffleKeyPartitionExprs(
+            LogicalRepeat<? extends Plan> repeat, int maxGroupIndex, 
List<Expression> maxGroupByList,
+            ConnectContext connectContext) {
+        int idx = 
connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup;
+        if (idx >= 0 && idx < maxGroupByList.size()) {
+            return Optional.of(ImmutableList.of(maxGroupByList.get(idx)));
+        }
+        if (repeat.child().getStats() == null) {
+            repeat.child().accept(new StatsDerive(false), new DeriveContext());
+        }
+        Statistics inputStats = repeat.child().getStats();
+        if (inputStats == null) {
+            return Optional.empty();
+        }
+        int beNumber = Math.max(1, 
connectContext.getEnv().getClusterInfo().getBackendsNumber(true));
+        int parallelInstance = Math.max(1, 
connectContext.getSessionVariable().getParallelExecInstanceNum());
+        int totalInstanceNum = beNumber * parallelInstance;
+        Optional<Expression> chosen;
+        switch (repeat.getRepeatType()) {
+            case CUBE:
+                // Prefer larger NDV to improve balance
+                chosen = chooseByNdv(maxGroupByList, inputStats, 
totalInstanceNum);
+                break;
+            case GROUPING_SETS:
+                chosen = chooseByAppearanceThenNdv(repeat.getGroupingSets(), 
maxGroupIndex, maxGroupByList,
+                        inputStats, totalInstanceNum);
+                break;
+            case ROLLUP:
+                chosen = chooseByRollupPrefixThenNdv(maxGroupByList, 
inputStats, totalInstanceNum);
+                break;
+            default:
+                chosen = Optional.empty();
+        }
+        return chosen.map(ImmutableList::of);
+    }
+
+    private Optional<Expression> chooseByNdv(List<Expression> candidates, 
Statistics inputStats, int totalInstanceNum) {
+        if (inputStats == null) {
+            return Optional.empty();
+        }
+        Comparator<Expression> cmp = Comparator.comparingDouble(e -> 
estimateNdv(e, inputStats));
+        Optional<Expression> choose = candidates.stream().max(cmp);
+        if (choose.isPresent() && estimateNdv(choose.get(), inputStats) > 
totalInstanceNum) {
+            return choose;
+        } else {
+            return Optional.empty();
+        }
+    }
+
+    /**
+     * GROUPING_SETS: prefer keys appearing in more (non-max) grouping sets, 
tie-break by larger NDV.
+     */
+    private Optional<Expression> 
chooseByAppearanceThenNdv(List<List<Expression>> groupingSets, int 
maxGroupIndex,
+            List<Expression> candidates, Statistics inputStats, int 
totalInstanceNum) {
+        Map<Expression, Integer> appearCount = new HashMap<>();
+        for (Expression c : candidates) {
+            appearCount.put(c, 0);
+        }
+        for (int i = 0; i < groupingSets.size(); i++) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            List<Expression> set = groupingSets.get(i);
+            for (Expression c : candidates) {
+                if (set.contains(c)) {
+                    appearCount.put(c, appearCount.get(c) + 1);
+                }
+            }
+        }
+        Map<Integer, List<Expression>> countToCandidate = new TreeMap<>();
+        for (Map.Entry<Expression, Integer> entry : appearCount.entrySet()) {
+            countToCandidate.computeIfAbsent(entry.getValue(), v -> new 
ArrayList<>()).add(entry.getKey());
+        }
+        for (Map.Entry<Integer, List<Expression>> entry : 
countToCandidate.entrySet()) {
+            Optional<Expression> chosen = chooseByNdv(entry.getValue(), 
inputStats, totalInstanceNum);
+            if (chosen.isPresent()) {
+                return chosen;
+            }
+        }
+        return Optional.empty();
+
+    }
+
+    /**
+     * ROLLUP: prefer earliest prefix key; if NDV is too low, fallback to next 
prefix.
+     */
+    private Optional<Expression> chooseByRollupPrefixThenNdv(List<Expression> 
candidates, Statistics inputStats,
+            int totalInstanceNum) {
+        for (Expression c : candidates) {
+            if (estimateNdv(c, inputStats) >= totalInstanceNum) {
+                return Optional.of(c);
+            }
+        }
+        return Optional.empty();
+    }
+
+    private double estimateNdv(Expression expr, Statistics stats) {
+        if (stats == null) {
+            return -1D;
+        }
+        ColumnStatistic col = stats.findColumnStatistics(expr);
+        if (col == null || col.isUnKnown()) {
+            return -1D;
+        }
+        return col.ndv;
+    }
+
     /**
      * Construct a new LogicalRepeat with reduced grouping sets and replaced 
expressions.
      * The grouping sets and output expressions are replaced using the slot 
mapping from producer to consumer.
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java
index c969863493c..7fe2dd26ae0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java
@@ -193,9 +193,18 @@ public class LogicalPlanDeepCopier extends 
DefaultPlanRewriter<DeepCopierContext
                 outputExpressions, child);
         Optional<LogicalRepeat<?>> childRepeat =
                 copiedAggregate.collectFirst(LogicalRepeat.class::isInstance);
-        return childRepeat.isPresent() ? 
aggregate.withChildGroupByAndOutputAndSourceRepeat(
-                groupByExpressions, outputExpressions, child, childRepeat)
-                : aggregate.withChildGroupByAndOutput(groupByExpressions, 
outputExpressions, child);
+        List<Expression> partitionExpressions = ImmutableList.of();
+        if (aggregate.getPartitionExpressions().isPresent()) {
+            partitionExpressions = 
aggregate.getPartitionExpressions().get().stream()
+                    .map(k -> ExpressionDeepCopier.INSTANCE.deepCopy(k, 
context))
+                    .collect(ImmutableList.toImmutableList());
+        }
+        Optional<List<Expression>> optionalPartitionExpressions = 
partitionExpressions.isEmpty()
+                ? Optional.empty() : Optional.of(partitionExpressions);
+        return childRepeat.isPresent() ? 
aggregate.withChildGroupByAndOutputAndSourceRepeatAndPartitionExpr(
+                groupByExpressions, outputExpressions, 
optionalPartitionExpressions, child, childRepeat)
+                : 
aggregate.withChildGroupByAndOutputAndPartitionExpr(groupByExpressions, 
outputExpressions,
+                        optionalPartitionExpressions, child);
     }
 
     @Override
@@ -211,7 +220,7 @@ public class LogicalPlanDeepCopier extends 
DefaultPlanRewriter<DeepCopierContext
                 .collect(ImmutableList.toImmutableList());
         SlotReference groupingId = (SlotReference) 
ExpressionDeepCopier.INSTANCE
                 .deepCopy(repeat.getGroupingId().get(), context);
-        return new LogicalRepeat<>(groupingSets, outputExpressions, 
groupingId, child);
+        return new LogicalRepeat<>(groupingSets, outputExpressions, 
groupingId, repeat.getRepeatType(), child);
     }
 
     @Override
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java
index e35b48073b5..7a7f1f1f3c2 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java
@@ -18,16 +18,15 @@
 package org.apache.doris.nereids.trees.plans.algebra;
 
 import org.apache.doris.nereids.exceptions.AnalysisException;
-import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
 import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.util.BitUtils;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
@@ -117,14 +116,35 @@ public interface Repeat<CHILD_PLAN extends Plan> extends 
Aggregate<CHILD_PLAN> {
 
     /**
      * flatten the grouping sets and build to a GroupingSetShapes.
+     * This method ensures that all expressions referenced by grouping 
functions are included
+     * in the flattenGroupingSetExpression, even if they are not in any 
grouping set.
+     * This is necessary for optimization scenarios where some expressions may 
only exist
+     * in the maximum grouping set that was removed during optimization.
      */
     default GroupingSetShapes toShapes() {
-        Set<Expression> flattenGroupingSet = 
ImmutableSet.copyOf(ExpressionUtils.flatExpressions(getGroupingSets()));
+        // Collect all expressions referenced by grouping functions to ensure 
they are included
+        // in flattenGroupingSetExpression, even if they are not in any 
grouping set.
+        // This maintains semantic constraints while allowing optimization.
+        List<GroupingScalarFunction> groupingFunctions = 
ExpressionUtils.collectToList(
+                getOutputExpressions(), 
GroupingScalarFunction.class::isInstance);
+        Set<Expression> groupingFunctionArgs = Sets.newLinkedHashSet();
+        for (GroupingScalarFunction function : groupingFunctions) {
+            groupingFunctionArgs.addAll(function.getArguments());
+        }
+        // Merge grouping set expressions with grouping function arguments
+        // Use LinkedHashSet to preserve order: grouping sets first, then 
grouping function args
+        Set<Expression> flattenGroupingSet = 
Sets.newLinkedHashSet(getGroupByExpressions());
+        for (Expression arg : groupingFunctionArgs) {
+            if (!flattenGroupingSet.contains(arg)) {
+                flattenGroupingSet.add(arg);
+            }
+        }
         List<GroupingSetShape> shapes = Lists.newArrayList();
         for (List<Expression> groupingSet : getGroupingSets()) {
             List<Boolean> shouldBeErasedToNull = 
Lists.newArrayListWithCapacity(flattenGroupingSet.size());
-            for (Expression groupingSetExpression : flattenGroupingSet) {
-                
shouldBeErasedToNull.add(!groupingSet.contains(groupingSetExpression));
+            for (Expression expression : flattenGroupingSet) {
+                // If expression is not in the current grouping set, it should 
be erased to null
+                shouldBeErasedToNull.add(!groupingSet.contains(expression));
             }
             shapes.add(new GroupingSetShape(shouldBeErasedToNull));
         }
@@ -140,8 +160,8 @@ public interface Repeat<CHILD_PLAN extends Plan> extends 
Aggregate<CHILD_PLAN> {
      *
      * return: [(4, 3), (3)]
      */
-    default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> 
slotIdList) {
-        List<Set<Integer>> groupingSetsIndexesInOutput = 
getGroupingSetsIndexesInOutput();
+    default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> 
slotIdList, List<Slot> outputSlots) {
+        List<Set<Integer>> groupingSetsIndexesInOutput = 
getGroupingSetsIndexesInOutput(outputSlots);
         List<Set<Integer>> repeatSlotIdList = Lists.newArrayList();
         for (Set<Integer> groupingSetIndex : groupingSetsIndexesInOutput) {
             // keep order
@@ -160,8 +180,8 @@ public interface Repeat<CHILD_PLAN extends Plan> extends 
Aggregate<CHILD_PLAN> {
      * e.g. groupingSets=((b, a), (a)), output=[a, b]
      * return ((1, 0), (1))
      */
-    default List<Set<Integer>> getGroupingSetsIndexesInOutput() {
-        Map<Expression, Integer> indexMap = indexesOfOutput();
+    default List<Set<Integer>> getGroupingSetsIndexesInOutput(List<Slot> 
outputSlots) {
+        Map<Expression, Integer> indexMap = indexesOfOutput(outputSlots);
 
         List<Set<Integer>> groupingSetsIndex = Lists.newArrayList();
         List<List<Expression>> groupingSets = getGroupingSets();
@@ -184,23 +204,22 @@ public interface Repeat<CHILD_PLAN extends Plan> extends 
Aggregate<CHILD_PLAN> {
     /**
      * indexesOfOutput: get the indexes which mapping from the expression to 
the index in the output.
      *
-     * e.g. output=[a + 1, b + 2, c]
+     * e.g. outputSlots=[a + 1, b + 2, c]
      *
      * return the map(
      *   `a + 1`: 0,
      *   `b + 2`: 1,
      *   `c`: 2
      * )
+     *
+     * Use outputSlots in physicalPlanTranslator instead of 
getOutputExpressions() in this method,
+     * because the outputSlots have same order with slotIdList.
      */
-    default Map<Expression, Integer> indexesOfOutput() {
+    static Map<Expression, Integer> indexesOfOutput(List<Slot> outputSlots) {
         Map<Expression, Integer> indexes = Maps.newLinkedHashMap();
-        List<NamedExpression> outputs = getOutputExpressions();
-        for (int i = 0; i < outputs.size(); i++) {
-            NamedExpression output = outputs.get(i);
+        for (int i = 0; i < outputSlots.size(); i++) {
+            NamedExpression output = outputSlots.get(i);
             indexes.put(output, i);
-            if (output instanceof Alias) {
-                indexes.put(((Alias) output).child(), i);
-            }
         }
         return indexes;
     }
@@ -302,4 +321,11 @@ public interface Repeat<CHILD_PLAN extends Plan> extends 
Aggregate<CHILD_PLAN> {
             return "GroupingSetShape(shouldBeErasedToNull=" + 
shouldBeErasedToNull + ")";
         }
     }
+
+    /** RepeatType */
+    enum RepeatType {
+        ROLLUP,
+        CUBE,
+        GROUPING_SETS
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index d063ccb40aa..b7b4e4f756b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -79,6 +79,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
     private final boolean generated;
     private final boolean hasPushed;
     private final boolean withInProjection;
+    private final Optional<List<Expression>> partitionExpressions;
 
     /**
      * Desc: Constructor for LogicalAggregate.
@@ -97,19 +98,19 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
     public LogicalAggregate(List<NamedExpression> namedExpressions, boolean 
generated, CHILD_TYPE child) {
         this(ImmutableList.copyOf(namedExpressions), namedExpressions,
                 false, true, generated, false, true, Optional.empty(),
-                Optional.empty(), Optional.empty(), child);
+                Optional.empty(), Optional.empty(), Optional.empty(), child);
     }
 
     public LogicalAggregate(List<NamedExpression> namedExpressions, boolean 
generated, boolean hasPushed,
             CHILD_TYPE child) {
         this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, 
true, generated, hasPushed, true,
-                Optional.empty(), Optional.empty(), Optional.empty(), child);
+                Optional.empty(), Optional.empty(), Optional.empty(), 
Optional.empty(), child);
     }
 
     public LogicalAggregate(List<Expression> groupByExpressions,
             List<NamedExpression> outputExpressions, boolean 
ordinalIsResolved, CHILD_TYPE child) {
         this(groupByExpressions, outputExpressions, false, ordinalIsResolved, 
false, false, true, Optional.empty(),
-                Optional.empty(), Optional.empty(), child);
+                Optional.empty(), Optional.empty(), Optional.empty(), child);
     }
 
     /**
@@ -131,7 +132,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
             Optional<LogicalRepeat<?>> sourceRepeat,
             CHILD_TYPE child) {
         this(groupByExpressions, outputExpressions, normalized, false, false, 
false, true, sourceRepeat,
-                Optional.empty(), Optional.empty(), child);
+                Optional.empty(), Optional.empty(), Optional.empty(), child);
     }
 
     /**
@@ -148,6 +149,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
             Optional<LogicalRepeat<?>> sourceRepeat,
             Optional<GroupExpression> groupExpression,
             Optional<LogicalProperties> logicalProperties,
+            Optional<List<Expression>> partitionExpressions,
             CHILD_TYPE child) {
         super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties, 
child);
         this.groupByExpressions = ImmutableList.copyOf(groupByExpressions);
@@ -162,6 +164,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
         this.hasPushed = hasPushed;
         this.sourceRepeat = Objects.requireNonNull(sourceRepeat, "sourceRepeat 
cannot be null");
         this.withInProjection = withInProjection;
+        this.partitionExpressions = partitionExpressions;
     }
 
     @Override
@@ -280,6 +283,16 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
                 .build();
     }
 
+    public Optional<List<Expression>> getPartitionExpressions() {
+        return partitionExpressions;
+    }
+
+    public LogicalAggregate<Plan> 
withPartitionExpressions(Optional<List<Expression>> newPartitionExpressions) {
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
normalized, ordinalIsResolved, generated,
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), newPartitionExpressions,
+                child());
+    }
+
     public boolean isNormalized() {
         return normalized;
     }
@@ -304,26 +317,29 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
                 && normalized == that.normalized
                 && ordinalIsResolved == that.ordinalIsResolved
                 && generated == that.generated
-                && Objects.equals(sourceRepeat, that.sourceRepeat);
+                && Objects.equals(sourceRepeat, that.sourceRepeat)
+                && Objects.equals(partitionExpressions, 
that.partitionExpressions);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(groupByExpressions, outputExpressions, normalized, 
ordinalIsResolved, sourceRepeat);
+        return Objects.hash(groupByExpressions, outputExpressions, normalized, 
ordinalIsResolved, sourceRepeat,
+                partitionExpressions);
     }
 
     @Override
     public LogicalAggregate<Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 1);
         return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), children.get(0));
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), partitionExpressions,
+                children.get(0));
     }
 
     @Override
     public LogicalAggregate<Plan> 
withGroupExpression(Optional<GroupExpression> groupExpression) {
         return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection,
-                sourceRepeat, groupExpression, 
Optional.of(getLogicalProperties()), children.get(0));
+                hasPushed, withInProjection, sourceRepeat, groupExpression, 
Optional.of(getLogicalProperties()),
+                partitionExpressions, children.get(0));
     }
 
     @Override
@@ -331,39 +347,52 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
             Optional<LogicalProperties> logicalProperties, List<Plan> 
children) {
         Preconditions.checkArgument(children.size() == 1);
         return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection,
-                sourceRepeat, groupExpression, 
Optional.of(getLogicalProperties()), children.get(0));
+                hasPushed, withInProjection, sourceRepeat, groupExpression, 
Optional.of(getLogicalProperties()),
+                partitionExpressions, children.get(0));
     }
 
     public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> 
groupByExprList,
             List<NamedExpression> outputExpressionList) {
         return new LogicalAggregate<>(groupByExprList, outputExpressionList, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), child());
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), partitionExpressions,
+                child());
     }
 
     public LogicalAggregate<Plan> withGroupBy(List<Expression> 
groupByExprList) {
         return new LogicalAggregate<>(groupByExprList, outputExpressions, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), child());
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), partitionExpressions,
+                child());
     }
 
     public LogicalAggregate<Plan> withChildGroupByAndOutput(List<Expression> 
groupByExprList,
             List<NamedExpression> outputExpressionList, Plan newChild) {
         return new LogicalAggregate<>(groupByExprList, outputExpressionList, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), newChild);
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), partitionExpressions,
+                newChild);
+    }
+
+    public LogicalAggregate<Plan> 
withChildGroupByAndOutputAndPartitionExpr(List<Expression> groupByExprList,
+            List<NamedExpression> outputExpressionList, 
Optional<List<Expression>> partitionExpressions,
+            Plan newChild) {
+        return new LogicalAggregate<>(groupByExprList, outputExpressionList, 
normalized, ordinalIsResolved, generated,
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(),
+                partitionExpressions, newChild);
     }
 
-    public LogicalAggregate<Plan> 
withChildGroupByAndOutputAndSourceRepeat(List<Expression> groupByExprList,
-                                                            
List<NamedExpression> outputExpressionList, Plan newChild,
-                                                            
Optional<LogicalRepeat<? extends Plan>> sourceRepeat) {
+    public LogicalAggregate<Plan> 
withChildGroupByAndOutputAndSourceRepeatAndPartitionExpr(
+            List<Expression> groupByExprList,
+            List<NamedExpression> outputExpressionList, 
Optional<List<Expression>> partitionExpressions, Plan newChild,
+            Optional<LogicalRepeat<? extends Plan>> sourceRepeat) {
         return new LogicalAggregate<>(groupByExprList, outputExpressionList, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), newChild);
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(),
+                partitionExpressions, newChild);
     }
 
     public LogicalAggregate<Plan> withChildAndOutput(CHILD_TYPE child,
                                                        List<NamedExpression> 
outputExpressionList) {
         return new LogicalAggregate<>(groupByExpressions, 
outputExpressionList, normalized, ordinalIsResolved,
                 generated, hasPushed, withInProjection, sourceRepeat, 
Optional.empty(),
-                Optional.empty(), child);
+                Optional.empty(), partitionExpressions, child);
     }
 
     @Override
@@ -374,30 +403,33 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
     @Override
     public LogicalAggregate<CHILD_TYPE> withAggOutput(List<NamedExpression> 
newOutput) {
         return new LogicalAggregate<>(groupByExpressions, newOutput, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), child());
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), partitionExpressions,
+                child());
     }
 
     public LogicalAggregate<Plan> withAggOutputChild(List<NamedExpression> 
newOutput, Plan newChild) {
         return new LogicalAggregate<>(groupByExpressions, newOutput, 
normalized, ordinalIsResolved, generated,
-                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), newChild);
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), partitionExpressions,
+                newChild);
     }
 
     public LogicalAggregate<Plan> withNormalized(List<Expression> 
normalizedGroupBy,
             List<NamedExpression> normalizedOutput, Plan normalizedChild) {
         return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, 
true, ordinalIsResolved, generated,
-                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), normalizedChild);
+                hasPushed, withInProjection, sourceRepeat, Optional.empty(), 
Optional.empty(), partitionExpressions,
+                normalizedChild);
     }
 
     public LogicalAggregate<Plan> withInProjection(boolean withInProjection) {
         return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
normalized, ordinalIsResolved,
                 generated, hasPushed, withInProjection,
-                sourceRepeat, Optional.empty(), Optional.empty(), child());
+                sourceRepeat, Optional.empty(), Optional.empty(), 
partitionExpressions, child());
     }
 
     public LogicalAggregate<Plan> withSourceRepeat(LogicalRepeat<?> 
sourceRepeat) {
         return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
normalized, ordinalIsResolved,
                 generated, hasPushed, withInProjection, 
Optional.ofNullable(sourceRepeat),
-                Optional.empty(), Optional.empty(), child());
+                Optional.empty(), Optional.empty(), partitionExpressions, 
child());
     }
 
     private boolean isUniqueGroupByUnique(NamedExpression namedExpression) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java
index a7ab86de4fb..7b10cd2ced4 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java
@@ -57,6 +57,7 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends 
LogicalUnary<CHILD_T
     private final List<NamedExpression> outputExpressions;
     private final Optional<SlotReference> groupingId;
     private final boolean withInProjection;
+    private final RepeatType type;
 
     /**
      * Desc: Constructor for LogicalRepeat.
@@ -64,8 +65,9 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends 
LogicalUnary<CHILD_T
     public LogicalRepeat(
             List<List<Expression>> groupingSets,
             List<NamedExpression> outputExpressions,
+            RepeatType type,
             CHILD_TYPE child) {
-        this(groupingSets, outputExpressions, Optional.empty(), child);
+        this(groupingSets, outputExpressions, Optional.empty(), type, child);
     }
 
     /**
@@ -75,9 +77,10 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends 
LogicalUnary<CHILD_T
             List<List<Expression>> groupingSets,
             List<NamedExpression> outputExpressions,
             SlotReference groupingId,
+            RepeatType type,
             CHILD_TYPE child) {
         this(groupingSets, outputExpressions, Optional.empty(), 
Optional.empty(),
-                Optional.ofNullable(groupingId), true, child);
+                Optional.ofNullable(groupingId), true, type, child);
     }
 
     /**
@@ -87,8 +90,9 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends 
LogicalUnary<CHILD_T
             List<List<Expression>> groupingSets,
             List<NamedExpression> outputExpressions,
             Optional<SlotReference> groupingId,
+            RepeatType type,
             CHILD_TYPE child) {
-        this(groupingSets, outputExpressions, Optional.empty(), 
Optional.empty(), groupingId, true, child);
+        this(groupingSets, outputExpressions, Optional.empty(), 
Optional.empty(), groupingId, true, type, child);
     }
 
     /**
@@ -96,7 +100,7 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends 
LogicalUnary<CHILD_T
      */
     private LogicalRepeat(List<List<Expression>> groupingSets, 
List<NamedExpression> outputExpressions,
             Optional<GroupExpression> groupExpression, 
Optional<LogicalProperties> logicalProperties,
-            Optional<SlotReference> groupingId, boolean withInProjection, 
CHILD_TYPE child) {
+            Optional<SlotReference> groupingId, boolean withInProjection, 
RepeatType type, CHILD_TYPE child) {
         super(PlanType.LOGICAL_REPEAT, groupExpression, logicalProperties, 
child);
         this.groupingSets = Objects.requireNonNull(groupingSets, "groupingSets 
can not be null")
                 .stream()
@@ -106,6 +110,7 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends 
LogicalUnary<CHILD_T
                 Objects.requireNonNull(outputExpressions, "outputExpressions 
can not be null"));
         this.groupingId = groupingId;
         this.withInProjection = withInProjection;
+        this.type = type;
     }
 
     @Override
@@ -122,6 +127,10 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHILD_T
         return groupingId;
     }
 
+    public RepeatType getRepeatType() {
+        return type;
+    }
+
     @Override
     public List<NamedExpression> getOutputs() {
         return outputExpressions;
@@ -217,13 +226,13 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHILD_T
     @Override
     public LogicalRepeat<Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 1);
-        return new LogicalRepeat<>(groupingSets, outputExpressions, 
groupingId, children.get(0));
+        return new LogicalRepeat<>(groupingSets, outputExpressions, 
groupingId, type, children.get(0));
     }
 
     @Override
     public LogicalRepeat<CHILD_TYPE> 
withGroupExpression(Optional<GroupExpression> groupExpression) {
         return new LogicalRepeat<>(groupingSets, outputExpressions, 
groupExpression,
-                Optional.of(getLogicalProperties()), groupingId, 
withInProjection, child());
+                Optional.of(getLogicalProperties()), groupingId, 
withInProjection, type, child());
     }
 
     @Override
@@ -231,35 +240,35 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHILD_T
             Optional<LogicalProperties> logicalProperties, List<Plan> 
children) {
         Preconditions.checkArgument(children.size() == 1);
         return new LogicalRepeat<>(groupingSets, outputExpressions, 
groupExpression, logicalProperties,
-                groupingId, withInProjection, children.get(0));
+                groupingId, withInProjection, type, children.get(0));
     }
 
     public LogicalRepeat<CHILD_TYPE> withGroupSets(List<List<Expression>> 
groupingSets) {
-        return new LogicalRepeat<>(groupingSets, outputExpressions, 
groupingId, child());
+        return new LogicalRepeat<>(groupingSets, outputExpressions, 
groupingId, type, child());
     }
 
     public LogicalRepeat<CHILD_TYPE> 
withGroupSetsAndOutput(List<List<Expression>> groupingSets,
             List<NamedExpression> outputExpressionList) {
-        return new LogicalRepeat<>(groupingSets, outputExpressionList, 
groupingId, child());
+        return new LogicalRepeat<>(groupingSets, outputExpressionList, 
groupingId, type, child());
     }
 
     @Override
     public LogicalRepeat<CHILD_TYPE> withAggOutput(List<NamedExpression> 
newOutput) {
-        return new LogicalRepeat<>(groupingSets, newOutput, groupingId, 
child());
+        return new LogicalRepeat<>(groupingSets, newOutput, groupingId, type, 
child());
     }
 
     public LogicalRepeat<Plan> withNormalizedExpr(List<List<Expression>> 
groupingSets,
             List<NamedExpression> outputExpressionList, SlotReference 
groupingId, Plan child) {
-        return new LogicalRepeat<>(groupingSets, outputExpressionList, 
groupingId, child);
+        return new LogicalRepeat<>(groupingSets, outputExpressionList, 
groupingId, type, child);
     }
 
     public LogicalRepeat<Plan> withAggOutputAndChild(List<NamedExpression> 
newOutput, Plan child) {
-        return new LogicalRepeat<>(groupingSets, newOutput, groupingId, child);
+        return new LogicalRepeat<>(groupingSets, newOutput, groupingId, type, 
child);
     }
 
     public LogicalRepeat<CHILD_TYPE> withInProjection(boolean 
withInProjection) {
         return new LogicalRepeat<>(groupingSets, outputExpressions,
-                Optional.empty(), Optional.empty(), groupingId, 
withInProjection, child());
+                Optional.empty(), Optional.empty(), groupingId, 
withInProjection, type, child());
     }
 
     @Override
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index 75504d3b76e..232ae772a27 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -839,6 +839,9 @@ public class SessionVariable implements Serializable, 
Writable {
     public static final String SKEW_REWRITE_JOIN_SALT_EXPLODE_FACTOR = 
"skew_rewrite_join_salt_explode_factor";
 
     public static final String SKEW_REWRITE_AGG_BUCKET_NUM = 
"skew_rewrite_agg_bucket_num";
+    public static final String DECOMPOSE_REPEAT_THRESHOLD = 
"decompose_repeat_threshold";
+    public static final String DECOMPOSE_REPEAT_SHUFFLE_INDEX_IN_MAX_GROUP
+            = "decompose_repeat_shuffle_index_in_max_group";
 
     public static final String HOT_VALUE_COLLECT_COUNT = 
"hot_value_collect_count";
     @VariableMgr.VarAttr(name = HOT_VALUE_COLLECT_COUNT, needForward = true,
@@ -3366,6 +3369,11 @@ public class SessionVariable implements Serializable, 
Writable {
     )
     public boolean useV3StorageFormat = false;
 
+    @VariableMgr.VarAttr(name = DECOMPOSE_REPEAT_THRESHOLD)
+    public int decomposeRepeatThreshold = 3;
+    @VariableMgr.VarAttr(name = DECOMPOSE_REPEAT_SHUFFLE_INDEX_IN_MAX_GROUP)
+    public int decomposeRepeatShuffleIndexInMaxGroup = -1;
+
     public static final String IGNORE_ICEBERG_DANGLING_DELETE = 
"ignore_iceberg_dangling_delete";
     @VariableMgr.VarAttr(name = IGNORE_ICEBERG_DANGLING_DELETE,
             description = {"是否忽略 Iceberg 表中 dangling delete 文件对 COUNT(*) 
统计信息的影响。"
@@ -3379,6 +3387,7 @@ public class SessionVariable implements Serializable, 
Writable {
                             + "to exclude the impact of dangling delete 
files."})
     public boolean ignoreIcebergDanglingDelete = false;
 
+
     // If this fe is in fuzzy mode, then will use initFuzzyModeVariables to 
generate some variables,
     // not the default value set in the code.
     @SuppressWarnings("checkstyle:Indentation")
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java
index 556f5279412..e2c874aad17 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java
@@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
@@ -44,6 +45,7 @@ public class NormalizeRepeatTest implements 
MemoPatternMatchSupported {
         Plan plan = new LogicalRepeat<>(
                 ImmutableList.of(ImmutableList.of(id), ImmutableList.of(name)),
                 ImmutableList.of(idNotNull, alias),
+                RepeatType.GROUPING_SETS,
                 scan1
         );
         PlanChecker.from(MemoTestUtils.createCascadesContext(plan))
@@ -62,6 +64,7 @@ public class NormalizeRepeatTest implements 
MemoPatternMatchSupported {
         Plan plan = new LogicalRepeat<>(
                 ImmutableList.of(ImmutableList.of(id)),
                 ImmutableList.of(idNotNull, alias),
+                RepeatType.GROUPING_SETS,
                 scan1
         );
         PlanChecker.from(MemoTestUtils.createCascadesContext(plan))
@@ -80,6 +83,7 @@ public class NormalizeRepeatTest implements 
MemoPatternMatchSupported {
         Plan plan = new LogicalRepeat<>(
                 ImmutableList.of(ImmutableList.of(id)),
                 ImmutableList.of(idNotNull, alias),
+                RepeatType.GROUPING_SETS,
                 scan1
         );
         PlanChecker.from(MemoTestUtils.createCascadesContext(plan))
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java
index c78394ce9ef..d1187b0e7cc 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java
@@ -28,6 +28,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
 import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
 import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
@@ -369,6 +370,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
                 groupingSets,
                 (List) ImmutableList.of(a, b),
                 new SlotReference("grouping_id", IntegerType.INSTANCE),
+                RepeatType.GROUPING_SETS,
                 emptyRelation);
         LogicalAggregate<Plan> aggregate = new LogicalAggregate<>(
                 ImmutableList.of(a, b),
@@ -459,6 +461,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
                 originalGroupingSets,
                 (List) ImmutableList.of(a, b, c),
                 new SlotReference("grouping_id", IntegerType.INSTANCE),
+                RepeatType.GROUPING_SETS,
                 emptyRelation);
 
         List<List<Expression>> newGroupingSets = ImmutableList.of(
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java
index 8bae1713fe1..44f5aa8e6bd 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java
@@ -30,6 +30,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
@@ -179,7 +180,7 @@ public class PushDownFilterThroughAggregationTest 
implements MemoPatternMatchSup
         Slot name = scan.getOutput().get(2);
         LogicalRepeat repeatPlan = new LogicalRepeat<>(
                 ImmutableList.of(ImmutableList.of(id, gender), 
ImmutableList.of(id)),
-                ImmutableList.of(id, gender, name), scan);
+                ImmutableList.of(id, gender, name), RepeatType.GROUPING_SETS, 
scan);
         NamedExpression nameMax = new Alias(new Max(name), "nameMax");
 
         final Expression filterPredicateId = new GreaterThan(id, 
Literal.of(1));
@@ -206,7 +207,7 @@ public class PushDownFilterThroughAggregationTest 
implements MemoPatternMatchSup
 
         repeatPlan = new LogicalRepeat<>(
                 ImmutableList.of(ImmutableList.of(id, gender), 
ImmutableList.of(gender)),
-                ImmutableList.of(id, gender, name), scan);
+                ImmutableList.of(id, gender, name), RepeatType.GROUPING_SETS, 
scan);
         plan = new LogicalPlanBuilder(repeatPlan)
                 .aggGroupUsingIndexAndSourceRepeat(ImmutableList.of(0, 1), 
ImmutableList.of(
                         id, gender, nameMax), Optional.of(repeatPlan))
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java
index 4ba8edfe34c..80842f6271b 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java
@@ -22,6 +22,7 @@ import 
org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
@@ -62,6 +63,7 @@ public class LogicalPlanDeepCopierTest {
                 groupingSets,
                 
scan.getOutput().stream().map(NamedExpression.class::cast).collect(Collectors.toList()),
                 groupingId,
+                RepeatType.GROUPING_SETS,
                 scan
         );
         List<? extends NamedExpression> groupByExprs = 
repeat.getOutput().subList(0, 1).stream()
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/algebra/RepeatTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/algebra/RepeatTest.java
new file mode 100644
index 00000000000..864fcc3e21d
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/algebra/RepeatTest.java
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.plans.algebra;
+
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Unit tests for {@link Repeat} interface default methods:
+ * toShapes, indexesOfOutput, getGroupingSetsIndexesInOutput, 
computeRepeatSlotIdList.
+ */
+public class RepeatTest {
+
+    private LogicalOlapScan scan;
+    private Slot id;
+    private Slot gender;
+    private Slot name;
+    private Slot age;
+
+    @BeforeEach
+    public void setUp() {
+        scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), 
PlanConstructor.student, ImmutableList.of("db"));
+        id = scan.getOutput().get(0);
+        gender = scan.getOutput().get(1);
+        name = scan.getOutput().get(2);
+        age = scan.getOutput().get(3);
+    }
+
+    @Test
+    public void testToShapes() {
+        // grouping sets: (id, name), (id), ()
+        // flatten = [id, name], shapes: [false,false], [false,true], 
[true,true]
+        List<List<Expression>> groupingSets = ImmutableList.of(
+                ImmutableList.of(id, name),
+                ImmutableList.of(id),
+                ImmutableList.of()
+        );
+        Alias alias = new Alias(new Sum(name), "sum(name)");
+        Repeat<Plan> repeat = new LogicalRepeat<>(
+                groupingSets,
+                ImmutableList.of(id, name, alias),
+                RepeatType.GROUPING_SETS,
+                scan
+        );
+
+        Repeat.GroupingSetShapes shapes = repeat.toShapes();
+
+        Assertions.assertEquals(2, shapes.flattenGroupingSetExpression.size());
+        
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(id));
+        
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(name));
+        Assertions.assertEquals(3, shapes.shapes.size());
+
+        // (id, name) -> [false, false]
+        Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(0));
+        Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(1));
+        Assertions.assertEquals(0L, shapes.shapes.get(0).computeLongValue());
+
+        // (id) -> [false, true] (id in set, name not)
+        Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(0));
+        Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(1));
+        Assertions.assertEquals(1L, shapes.shapes.get(1).computeLongValue());
+
+        // () -> [true, true]
+        Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(0));
+        Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(1));
+        Assertions.assertEquals(3L, shapes.shapes.get(2).computeLongValue());
+    }
+
+    @Test
+    public void testToShapesWithGroupingFunction() {
+        // grouping(id) adds id to flatten if not present; single set (name) 
has flatten [name, id]
+        List<List<Expression>> groupingSets = ImmutableList.of(
+                ImmutableList.of(name), ImmutableList.of(name, id), 
ImmutableList.of());
+        Alias groupingAlias = new Alias(new GroupingId(gender, age), 
"grouping_id(id)");
+        Repeat<Plan> repeat = new LogicalRepeat<>(
+                groupingSets,
+                ImmutableList.of(name, groupingAlias),
+                RepeatType.GROUPING_SETS,
+                scan
+        );
+
+        Repeat.GroupingSetShapes shapes = repeat.toShapes();
+
+        // flatten = [name] from getGroupBy + [id] from grouping function arg
+        Assertions.assertEquals(4, shapes.flattenGroupingSetExpression.size());
+        
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(name));
+        
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(id));
+        
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(gender));
+        
Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(age));
+
+        Assertions.assertEquals(3, shapes.shapes.size());
+        // (name) -> name not erased, id,gender,age erased
+        Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(0));
+        Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(1));
+        Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(2));
+        Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(3));
+        // (name, id) -> name,id not erased, gender and age erased
+        Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(0));
+        Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(1));
+        Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(2));
+        Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(3));
+        //() -> all erased
+        Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(0));
+        Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(1));
+        Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(2));
+        Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(3));
+    }
+
+    @Test
+    public void testIndexesOfOutput() {
+        List<Slot> outputSlots = ImmutableList.of(id, gender, name, age);
+        Map<Expression, Integer> indexes = Repeat.indexesOfOutput(outputSlots);
+        Assertions.assertEquals(4, indexes.size());
+        Assertions.assertEquals(0, indexes.get(id));
+        Assertions.assertEquals(1, indexes.get(gender));
+        Assertions.assertEquals(2, indexes.get(name));
+        Assertions.assertEquals(3, indexes.get(age));
+    }
+
+    @Test
+    public void testGetGroupingSetsIndexesInOutput() {
+        // groupingSets=((name, id), (id), (gender)), output=[id, name, gender]
+        // expect:((1,0),(0),(2))
+        List<List<Expression>> groupingSets = ImmutableList.of(
+                ImmutableList.of(name, id),
+                ImmutableList.of(id),
+                ImmutableList.of(gender)
+        );
+        Alias groupingId = new Alias(new GroupingId(id, name));
+        Repeat<Plan> repeat = new LogicalRepeat<>(
+                groupingSets,
+                ImmutableList.of(id, name, gender, groupingId),
+                RepeatType.GROUPING_SETS,
+                scan
+        );
+        List<Slot> outputSlots = ImmutableList.of(id, name, gender, 
groupingId.toSlot());
+
+        List<Set<Integer>> result = 
repeat.getGroupingSetsIndexesInOutput(outputSlots);
+
+        Assertions.assertEquals(3, result.size());
+        // (name, id) -> indexes {1, 0}
+        Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(1, 0)), 
result.get(0));
+        // (id) -> index {0}
+        Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(0)), 
result.get(1));
+        // (gender) -> index {2}
+        Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(2)), 
result.get(2));
+    }
+
+    @Test
+    public void testComputeRepeatSlotIdList() {
+        // groupingSets=((name, id), (id)), output=[id, name], slotIdList=[3, 
4] (id->3, name->4)
+        List<List<Expression>> groupingSets = ImmutableList.of(
+                ImmutableList.of(name, id),
+                ImmutableList.of(id)
+        );
+        Repeat<Plan> repeat = new LogicalRepeat<>(
+                groupingSets,
+                ImmutableList.of(id, name),
+                RepeatType.GROUPING_SETS,
+                scan
+        );
+        List<Slot> outputSlots = ImmutableList.of(id, name);
+        List<Integer> slotIdList = ImmutableList.of(3, 4);
+
+        List<Set<Integer>> result = repeat.computeRepeatSlotIdList(slotIdList, 
outputSlots);
+
+        Assertions.assertEquals(2, result.size());
+        // (name, id) -> indexes {1,0} -> slot ids {4, 3}
+        Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(4, 3)), 
result.get(0));
+        // (id) -> index {0} -> slot id {3}
+        Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(3)), 
result.get(1));
+    }
+}
diff --git a/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out 
b/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out
index e6516a0d47c..f8ab9595435 100644
--- a/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out
+++ b/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out
@@ -37,7 +37,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
 100000
 100000
 100000
-100000
 
 -- !sql_2_shape --
 PhysicalCteAnchor ( cteId=CTEId#0 )
@@ -60,11 +59,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
 -- !sql_2_result --
 \N     ALL     1       6       \N      \N      \N
 \N     ALL     1       6       \N      \N      \N
-2020-01-02T00:00       ALL     1       6       \N      2020-01-02T00:00        
\N
-2020-01-02T00:00       ALL     1       6       \N      2020-01-02T00:00        
\N
-2020-01-03T00:00       ALL     1       6       \N      2020-01-03T00:00        
\N
-2020-01-03T00:00       ALL     1       6       \N      2020-01-03T00:00        
\N
-2020-01-04T00:00       ALL     1       6       \N      2020-01-04T00:00        
\N
-2020-01-04T00:00       ALL     1       6       \N      2020-01-04T00:00        
\N
+2020-01-04T00:00       ALL     1       6       \N      \N      a
+2020-01-04T00:00       ALL     1       6       \N      \N      a
+2020-01-04T00:00       ALL     1       6       \N      \N      b
+2020-01-04T00:00       ALL     1       6       \N      \N      b
 2020-01-04T00:00       ALL     1       7       \N      \N      \N
 
diff --git 
a/regression-test/data/nereids_rules_p0/decompose_repeat/decompose_repeat.out 
b/regression-test/data/nereids_rules_p0/decompose_repeat/decompose_repeat.out
index 919738109c1..d5245f25084 100644
--- 
a/regression-test/data/nereids_rules_p0/decompose_repeat/decompose_repeat.out
+++ 
b/regression-test/data/nereids_rules_p0/decompose_repeat/decompose_repeat.out
@@ -369,3 +369,187 @@
 1      3       2       \N      2       0
 1      3       2       2       2       0
 
+-- !grouping_only_in_max --
+\N     \N      \N      1
+1      \N      \N      1
+1      2       \N      1
+1      2       1       0
+1      2       3       0
+1      3       \N      1
+1      3       2       0
+
+-- !grouping_id_only_in_max_c_d --
+\N     \N      \N      15
+1      \N      \N      7
+1      2       \N      3
+1      2       1       0
+1      2       3       0
+1      2       3       0
+1      3       \N      3
+1      3       2       0
+
+-- !grouping_id_only_in_max_d --
+\N     \N      \N      15
+1      \N      \N      7
+1      2       1       0
+1      2       1       1
+1      2       3       0
+1      2       3       0
+1      2       3       1
+1      3       2       0
+1      3       2       1
+
+-- !multi_grouping_func --
+\N     \N      \N      \N      7       7       7       3
+1      \N      \N      \N      3       6       5       0
+1      2       1       \N      0       0       0       0
+1      2       1       1       0       0       0       0
+1      2       3       \N      0       0       0       0
+1      2       3       3       0       0       0       0
+1      2       3       4       0       0       0       0
+1      3       2       \N      0       0       0       0
+1      3       2       2       0       0       0       0
+
+-- !grouping_partial_only_in_max --
+\N     \N      \N      \N      7
+1      2       \N      \N      3
+1      2       1       \N      1
+1      2       1       1       0
+1      2       3       \N      1
+1      2       3       3       0
+1      2       3       4       0
+1      3       \N      \N      3
+1      3       2       \N      1
+1      3       2       2       0
+
+-- !mixed_grouping_func_1 --
+\N     \N      \N      \N      1       7
+1      \N      \N      \N      0       7
+1      2       1       \N      0       1
+1      2       1       1       0       0
+1      2       3       \N      0       1
+1      2       3       3       0       0
+1      2       3       4       0       0
+1      3       2       \N      0       1
+1      3       2       2       0       0
+
+-- !grouping_all_in_other --
+\N     \N      \N      \N      3
+1      \N      \N      \N      1
+1      2       \N      \N      0
+1      2       1       \N      0
+1      2       1       1       0
+1      2       3       \N      0
+1      2       3       3       0
+1      2       3       4       0
+1      3       \N      \N      0
+1      3       2       \N      0
+1      3       2       2       0
+
+-- !grouping_dup_col --
+\N     \N      \N      \N      31
+1      \N      \N      \N      10
+1      2       1       \N      0
+1      2       1       1       0
+1      2       3       \N      0
+1      2       3       3       0
+1      2       3       4       0
+1      3       2       \N      0
+1      3       2       2       0
+
+-- !mixed_grouping_both --
+\N     \N      \N      \N      1       1       7       3
+1      \N      \N      \N      0       1       3       3
+1      2       1       \N      0       0       0       1
+1      2       1       1       0       0       0       0
+1      2       3       \N      0       0       0       1
+1      2       3       3       0       0       0       0
+1      2       3       4       0       0       0       0
+1      3       2       \N      0       0       0       1
+1      3       2       2       0       0       0       0
+
+-- !grouping_different_pos --
+\N     \N      \N      \N      3
+1      \N      1       \N      3
+1      \N      2       \N      3
+1      \N      3       \N      3
+1      2       \N      \N      1
+1      2       1       1       0
+1      2       3       3       0
+1      2       3       4       0
+1      3       \N      \N      1
+1      3       2       2       0
+
+-- !grouping_nested_case --
+\N     \N      \N      \N      0
+1      \N      \N      \N      0
+1      2       1       \N      0
+1      2       1       1       1
+1      2       3       \N      0
+1      2       3       3       1
+1      2       3       4       1
+1      3       2       \N      0
+1      3       2       2       1
+
+-- !grouping_mixed_params_1 --
+\N     \N      \N      \N      7
+1      \N      \N      \N      3
+1      2       \N      \N      1
+1      2       1       \N      1
+1      2       1       1       0
+1      2       3       \N      1
+1      2       3       3       0
+1      2       3       4       0
+1      3       \N      \N      1
+1      3       2       \N      1
+1      3       2       2       0
+
+-- !grouping_single_param_multi --
+\N     \N      \N      \N      1
+1      \N      1       \N      0
+1      \N      2       \N      0
+1      \N      3       \N      0
+1      2       1       \N      0
+1      2       1       1       0
+1      2       3       \N      0
+1      2       3       3       0
+1      2       3       4       0
+1      3       2       \N      0
+1      3       2       2       0
+
+-- !grouping_multi_combinations --
+\N     \N      \N      \N      1       3       7       15
+1      \N      \N      \N      0       1       3       7
+1      2       \N      \N      0       0       1       3
+1      2       1       \N      0       0       0       1
+1      2       1       1       0       0       0       0
+1      2       3       \N      0       0       0       1
+1      2       3       3       0       0       0       0
+1      2       3       4       0       0       0       0
+1      3       \N      \N      0       0       1       3
+1      3       2       \N      0       0       0       1
+1      3       2       2       0       0       0       0
+
+-- !grouping_max_not_first --
+\N     \N      \N      \N      3
+1      2       \N      \N      3
+1      2       1       \N      1
+1      2       1       1       0
+1      2       3       \N      1
+1      2       3       3       0
+1      2       3       4       0
+1      3       \N      \N      3
+1      3       2       \N      1
+1      3       2       2       0
+
+-- !grouping_with_agg --
+\N     \N      \N      \N      10      7
+1      \N      \N      \N      10      3
+1      2       1       \N      1       0
+1      2       1       1       1       0
+1      2       3       \N      7       0
+1      2       3       3       3       0
+1      2       3       4       4       0
+1      3       2       \N      2       0
+1      3       2       2       2       0
+
diff --git 
a/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy
 
b/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy
index 338517afbc4..5d06776679b 100644
--- 
a/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy
@@ -71,4 +71,36 @@ suite("decompose_repeat") {
     order_qt_cube "select a,b,c,d,sum(d),grouping_id(a) from t1 group by 
cube(a,b,c,d)"
     order_qt_cube_add "select a,b,c,d,sum(d)+100+grouping_id(a) from t1 group 
by cube(a,b,c,d);"
     order_qt_cube_sum_parm_add "select a,b,c,d,sum(a+1),grouping_id(a) from t1 
group by cube(a,b,c,d);"
+
+    // grouping scalar functions add more test
+    order_qt_grouping_only_in_max "select a,b,c, grouping(c) from t1 group by 
grouping sets((a,b,c),(a,b),(a),());"
+    order_qt_grouping_id_only_in_max_c_d "select a,b,c, grouping_id(a,b,c,d) 
from t1 group by grouping sets((a,b,c,d),(a,b),(a),());"
+    order_qt_grouping_id_only_in_max_d "select a,b,c, grouping_id(a,b,c,d) 
from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());"
+    order_qt_multi_grouping_func "select a,b,c,d, grouping_id(a,b,c), 
grouping_id(c,b,a), grouping_id(c,a,b), grouping_id(a,a) from t1 group by 
grouping sets((a,b,c,d),(a,b,c),(a),());"
+    
+    // more test cases for grouping scalar function bug(added by ai)
+    // Test case: grouping function with partial parameters only in max group
+    order_qt_grouping_partial_only_in_max "select a,b,c,d, grouping_id(a,c,d) 
from t1 group by grouping sets((a,b,c,d),(a,b,c),(a,b),());"
+    // Test case: multiple grouping functions, some can optimize and some 
cannot
+    order_qt_mixed_grouping_func_1 "select a,b,c,d, grouping(a), 
grouping_id(b,c,d) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());"
+    // Test case: grouping function with all parameters exist in other groups 
(should optimize)
+    order_qt_grouping_all_in_other "select a,b,c,d, grouping_id(a,b) from t1 
group by grouping sets((a,b,c,d),(a,b,c),(a,b),(a),());"
+    // Test case: grouping function with same column repeated
+    order_qt_grouping_dup_col "select a,b,c,d, grouping_id(a,b,a,c,a) from t1 
group by grouping sets((a,b,c,d),(a,b,c),(a),());"
+    // Test case: both grouping and grouping_id with different parameters
+    order_qt_mixed_grouping_both "select a,b,c,d, grouping(a), grouping(b), 
grouping_id(a,b,c), grouping_id(c,d) from t1 group by grouping 
sets((a,b,c,d),(a,b,c),(a),());"
+    // Test case: grouping function with columns in different positions
+    order_qt_grouping_different_pos "select a,b,c,d, grouping_id(b,d) from t1 
group by grouping sets((a,b,c,d),(a,b),(a,c),());"
+    // Test case: nested case with grouping functions that reference only-max 
columns
+    order_qt_grouping_nested_case "select a,b,c,d, case when grouping(d) = 1 
then 0 else 1 end from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());"
+    // Test case: grouping function parameter mix - one in max only, others in 
all groups
+    order_qt_grouping_mixed_params_1 "select a,b,c,d, grouping_id(a,b,d) from 
t1 group by grouping sets((a,b,c,d),(a,b,c),(a,b),(a),());"
+    // Test case: grouping function with single parameter that exists in 
multiple groups
+    order_qt_grouping_single_param_multi "select a,b,c,d, grouping(c) from t1 
group by grouping sets((a,b,c,d),(a,b,c),(a,c),());"
+    // Test case: multiple grouping_id functions with different parameter 
combinations
+    order_qt_grouping_multi_combinations "select a,b,c,d, grouping_id(a), 
grouping_id(a,b), grouping_id(a,b,c), grouping_id(a,b,c,d) from t1 group by 
grouping sets((a,b,c,d),(a,b,c),(a,b),(a),());"
+    // Test case: grouping function where max group is not first
+    order_qt_grouping_max_not_first "select a,b,c,d, grouping_id(c,d) from t1 
group by grouping sets((a,b),(a,b,c),(a,b,c,d),());"
+    // Test case: complex case with aggregation function and grouping function
+    order_qt_grouping_with_agg "select a,b,c,d, sum(d), grouping_id(a,b,c) 
from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());"
 }
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to