This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new d705eb05382 branch-3.1: [enhance](nereids) add rule MultiDistinctSplit 
#45209 (#51964)
d705eb05382 is described below

commit d705eb0538222cb35c2ba7721ef29a9b522a510a
Author: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Fri Jun 20 12:07:43 2025 +0800

    branch-3.1: [enhance](nereids) add rule MultiDistinctSplit #45209 (#51964)
    
    Cherry-picked from #45209
    
    Co-authored-by: feiniaofeiafei <[email protected]>
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   4 +
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../nereids/rules/analysis/CheckAnalysis.java      |  31 ---
 .../rules/implementation/AggregateStrategies.java  |  13 +-
 .../nereids/rules/rewrite/CheckMultiDistinct.java  |  31 +++
 .../nereids/rules/rewrite/SplitMultiDistinct.java  | 291 +++++++++++++++++++++
 .../trees/expressions/functions/agg/Count.java     |   8 +-
 .../expressions/functions/agg/GroupConcat.java     |   3 +-
 .../trees/expressions/functions/agg/Sum.java       |   3 +-
 .../trees/expressions/functions/agg/Sum0.java      |   3 +-
 .../functions/agg/SupportMultiDistinct.java        |  25 ++
 .../rules/rewrite/SplitMultiDistinctTest.java      | 191 ++++++++++++++
 .../distinct_split/disitinct_split.out             | Bin 0 -> 9519 bytes
 .../distinct_split/disitinct_split.groovy          | 210 +++++++++++++++
 .../aggregate_without_roll_up.groovy               |   4 +-
 .../mv/dimension/dimension_1.groovy                |   2 +-
 .../mv/dimension/dimension_2_3.groovy              |   2 +-
 .../mv/dimension/dimension_2_4.groovy              |   2 +-
 .../nereids_syntax_p0/aggregate_strategies.groovy  |   6 -
 19 files changed, 774 insertions(+), 56 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 57df7f9999b..619ca37e77d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -134,6 +134,7 @@ import 
org.apache.doris.nereids.rules.rewrite.RewriteCteChildren;
 import org.apache.doris.nereids.rules.rewrite.SetPreAggStatus;
 import org.apache.doris.nereids.rules.rewrite.SimplifyWindowExpression;
 import org.apache.doris.nereids.rules.rewrite.SplitLimit;
+import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct;
 import org.apache.doris.nereids.rules.rewrite.SumLiteralRewrite;
 import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAgg;
 import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject;
@@ -552,6 +553,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
                     rewriteJobs.addAll(jobs(topic("or expansion",
                             custom(RuleType.OR_EXPANSION, () -> 
OrExpansion.INSTANCE))));
                 }
+                rewriteJobs.addAll(jobs(topic("split multi distinct",
+                        custom(RuleType.SPLIT_MULTI_DISTINCT, () -> 
SplitMultiDistinct.INSTANCE))));
+
                 if (needSubPathPushDown) {
                     rewriteJobs.addAll(jobs(
                             topic("variant element_at push down",
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index f64d3950694..4f222db3f96 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -298,6 +298,7 @@ public enum RuleType {
     MERGE_TOP_N(RuleTypeClass.REWRITE),
     BUILD_AGG_FOR_UNION(RuleTypeClass.REWRITE),
     COUNT_DISTINCT_REWRITE(RuleTypeClass.REWRITE),
+    SPLIT_MULTI_DISTINCT(RuleTypeClass.REWRITE),
     INNER_TO_CROSS_JOIN(RuleTypeClass.REWRITE),
     CROSS_TO_INNER_JOIN(RuleTypeClass.REWRITE),
     PRUNE_EMPTY_PARTITION(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java
index 7ca8637446b..13455720b07 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java
@@ -21,7 +21,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.OrderExpression;
 import org.apache.doris.nereids.trees.expressions.WindowExpression;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import 
org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
@@ -139,36 +138,6 @@ public class CheckAnalysis implements AnalysisRuleFactory {
     }
 
     private void checkAggregate(LogicalAggregate<? extends Plan> aggregate) {
-        Set<AggregateFunction> aggregateFunctions = 
aggregate.getAggregateFunctions();
-        boolean distinctMultiColumns = false;
-        for (AggregateFunction func : aggregateFunctions) {
-            if (!func.isDistinct()) {
-                continue;
-            }
-            if (func.arity() <= 1) {
-                continue;
-            }
-            for (int i = 1; i < func.arity(); i++) {
-                if (!func.child(i).getInputSlots().isEmpty() && 
!(func.child(i) instanceof OrderExpression)) {
-                    // think about group_concat(distinct col_1, ',')
-                    distinctMultiColumns = true;
-                    break;
-                }
-            }
-            if (distinctMultiColumns) {
-                break;
-            }
-        }
-
-        long distinctFunctionNum = 0;
-        for (AggregateFunction aggregateFunction : aggregateFunctions) {
-            distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0;
-        }
-
-        if (distinctMultiColumns && distinctFunctionNum > 1) {
-            throw new AnalysisException(
-                    "The query contains multi count distinct or sum distinct, 
each can't have multi columns");
-        }
         for (Expression expr : aggregate.getGroupByExpressions()) {
             if (expr.anyMatch(AggregateFunction.class::isInstance)) {
                 throw new AnalysisException(
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index ecd951f1eb7..fa511631fc6 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -52,9 +52,9 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
-import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
@@ -1811,15 +1811,8 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
     }
 
     private AggregateFunction tryConvertToMultiDistinct(AggregateFunction 
function) {
-        if (function instanceof Count && function.isDistinct()) {
-            return new MultiDistinctCount(function.getArgument(0),
-                    function.getArguments().subList(1, 
function.arity()).toArray(new Expression[0]));
-        } else if (function instanceof Sum && function.isDistinct()) {
-            return ((Sum) function).convertToMultiDistinct();
-        } else if (function instanceof Sum0 && function.isDistinct()) {
-            return ((Sum0) function).convertToMultiDistinct();
-        } else if (function instanceof GroupConcat && function.isDistinct()) {
-            return ((GroupConcat) function).convertToMultiDistinct();
+        if (function instanceof SupportMultiDistinct && function.isDistinct()) 
{
+            return ((SupportMultiDistinct) function).convertToMultiDistinct();
         }
         return function;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java
index 4488a94b8d1..dd76457c411 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.OrderExpression;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@@ -57,6 +58,36 @@ public class CheckMultiDistinct extends 
OneRewriteRuleFactory {
                 }
             }
         }
+
+        boolean distinctMultiColumns = false;
+        for (AggregateFunction func : aggregate.getAggregateFunctions()) {
+            if (!func.isDistinct()) {
+                continue;
+            }
+            if (func.arity() <= 1) {
+                continue;
+            }
+            for (int i = 1; i < func.arity(); i++) {
+                if (!func.child(i).getInputSlots().isEmpty() && 
!(func.child(i) instanceof OrderExpression)) {
+                    // think about group_concat(distinct col_1, ',')
+                    distinctMultiColumns = true;
+                    break;
+                }
+            }
+            if (distinctMultiColumns) {
+                break;
+            }
+        }
+
+        long distinctFunctionNum = 0;
+        for (AggregateFunction aggregateFunction : 
aggregate.getAggregateFunctions()) {
+            distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0;
+        }
+
+        if (distinctMultiColumns && distinctFunctionNum > 1) {
+            throw new AnalysisException(
+                    "The query contains multi count distinct or sum distinct, 
each can't have multi columns");
+        }
         return aggregate;
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinct.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinct.java
new file mode 100644
index 00000000000..56df1485a83
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinct.java
@@ -0,0 +1,291 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.StatementContext;
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct.DistinctSplitContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.OrderExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Splits an aggregate that contains more than one distinct aggregate function implementing
+ * {@link SupportMultiDistinct} into one aggregate per distinct function. All of them read the
+ * aggregate's input through a CTE, and their results are joined back together.
+ *
+ * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2)
+ *   +--Plan
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer
+ *     +--Plan
+ *   +--LogicalProject(c1, c2)
+ *     +--LogicalJoin
+ *       +--LogicalAggregate(output:count(distinct a))
+ *         +--LogicalCTEConsumer
+ *       +--LogicalAggregate(output:count(distinct b))
+ *         +--LogicalCTEConsumer
+ *
+ * Without group by keys every partial aggregate returns a single row, and they are cross joined.
+ * With group by keys they are inner joined on those keys using null-safe equality, so that a group
+ * whose key is NULL is not lost by the join.
+ */
+public class SplitMultiDistinct extends DefaultPlanRewriter<DistinctSplitContext> implements CustomRewriter {
+    public static final SplitMultiDistinct INSTANCE = new SplitMultiDistinct();
+
+    /**
+     * State for one rewrite scope. It holds the CTE producers created for split aggregates; the code that
+     * owns the scope wraps them around the rewritten plan as CTE anchors.
+     */
+    public static class DistinctSplitContext {
+        List<LogicalCTEProducer<? extends Plan>> cteProducerList;
+        StatementContext statementContext;
+        CascadesContext cascadesContext;
+
+        public DistinctSplitContext(StatementContext statementContext, CascadesContext cascadesContext) {
+            this.statementContext = statementContext;
+            this.cteProducerList = new ArrayList<>();
+            this.cascadesContext = cascadesContext;
+        }
+    }
+
+    /**
+     * Rewrites the whole plan, then anchors every producer created in the top-level scope above the result.
+     * Anchors are added in reverse creation order, so the first producer created becomes the outermost
+     * anchor. A nested aggregate is visited, and its producer registered, before the aggregate above it.
+     */
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSplitContext ctx = new DistinctSplitContext(
+                jobContext.getCascadesContext().getStatementContext(), jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    /**
+     * Rewrites both sides of an existing CTE anchor. Producers created inside the producer side (child 0)
+     * go to the enclosing scope. Producers created inside the consumer side (child 1) are anchored
+     * directly above that child.
+     */
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, DistinctSplitContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSplitContext consumerContext =
+                new DistinctSplitContext(ctx.statementContext, ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    /**
+     * Rewrites the child first. If {@link #needTransform} accepts the aggregate, it is replaced by a
+     * project over a join of per-function aggregates, each reading a new CTE built from a deep copy of the
+     * aggregate's input. The project re-exposes the original output ExprIds and names.
+     */
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, DistinctSplitContext ctx) {
+        Plan newChild = agg.child().accept(this, ctx);
+        agg = agg.withChildren(ImmutableList.of(newChild));
+        List<Alias> distinctFuncWithAlias = new ArrayList<>();
+        List<Alias> otherAggFuncs = new ArrayList<>();
+        if (!needTransform((LogicalAggregate<Plan>) agg, distinctFuncWithAlias, otherAggFuncs)) {
+            return agg;
+        }
+
+        // the producer reads a deep copy of the input, so the original input slots are remapped
+        // position by position onto the copied ones
+        LogicalAggregate<Plan> cloneAgg = (LogicalAggregate<Plan>) LogicalPlanDeepCopier.INSTANCE
+                .deepCopy(agg, new DeepCopierContext());
+        LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.statementContext.getNextCTEId(),
+                cloneAgg.child());
+        ctx.cteProducerList.add(producer);
+        Map<Slot, Slot> originToProducerSlot = new HashMap<>();
+        for (int i = 0; i < agg.child().getOutput().size(); ++i) {
+            Slot originSlot = agg.child().getOutput().get(i);
+            Slot cloneSlot = cloneAgg.child().getOutput().get(i);
+            originToProducerSlot.put(originSlot, cloneSlot);
+        }
+        distinctFuncWithAlias = ExpressionUtils.replace((List) distinctFuncWithAlias, originToProducerSlot);
+        otherAggFuncs = ExpressionUtils.replace((List) otherAggFuncs, originToProducerSlot);
+        // construct one cte consumer and one aggregate per distinct function
+        List<LogicalAggregate<Plan>> newAggs = new ArrayList<>();
+        // new output alias -> original output alias. A LinkedHashMap keeps the final project's output order
+        // deterministic: distinct functions in their original order, then the other aggregate functions.
+        Map<Alias, Alias> newToOriginDistinctFuncAlias = new LinkedHashMap<>();
+        List<Expression> outputJoinGroupBys = new ArrayList<>();
+        for (int i = 0; i < distinctFuncWithAlias.size(); ++i) {
+            Expression distinctAggFunc = distinctFuncWithAlias.get(i).child(0);
+            Map<Slot, Slot> producerToConsumerSlotMap = new HashMap<>();
+            List<NamedExpression> outputExpressions = new ArrayList<>();
+            List<Expression> replacedGroupBy = new ArrayList<>();
+            LogicalCTEConsumer consumer = constructConsumerAndReplaceGroupBy(ctx, producer, cloneAgg,
+                    outputExpressions, producerToConsumerSlotMap, replacedGroupBy);
+            Expression newDistinctAggFunc = ExpressionUtils.replace(distinctAggFunc, producerToConsumerSlotMap);
+            Alias alias = new Alias(newDistinctAggFunc);
+            outputExpressions.add(alias);
+            if (i == 0) {
+                // the final project reads the group by keys from the first aggregate
+                outputJoinGroupBys.addAll(replacedGroupBy);
+            }
+            LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer);
+            newAggs.add(newAgg);
+            newToOriginDistinctFuncAlias.put(alias, distinctFuncWithAlias.get(i));
+        }
+        // all otherAggFuncs are placed together in one extra aggregate appended after the distinct ones
+        buildOtherAggFuncAggregate(otherAggFuncs, producer, ctx, cloneAgg, newToOriginDistinctFuncAlias, newAggs);
+        List<Expression> groupBy = agg.getGroupByExpressions();
+        LogicalJoin<Plan, Plan> join = constructJoin(newAggs, groupBy);
+        return constructProject(groupBy, newToOriginDistinctFuncAlias, outputJoinGroupBys, join);
+    }
+
+    /**
+     * Appends to {@code newAggs} one aggregate over a new consumer that computes every aggregate
+     * function not split out on its own. It records new-to-original alias pairs in
+     * {@code newToOriginDistinctFuncAlias}. It does nothing when {@code otherAggFuncs} is empty.
+     */
+    private static void buildOtherAggFuncAggregate(List<Alias> otherAggFuncs, LogicalCTEProducer<Plan> producer,
+            DistinctSplitContext ctx, LogicalAggregate<Plan> cloneAgg, Map<Alias, Alias> newToOriginDistinctFuncAlias,
+            List<LogicalAggregate<Plan>> newAggs) {
+        if (otherAggFuncs.isEmpty()) {
+            return;
+        }
+        Map<Slot, Slot> producerToConsumerSlotMap = new HashMap<>();
+        List<NamedExpression> outputExpressions = new ArrayList<>();
+        List<Expression> replacedGroupBy = new ArrayList<>();
+        LogicalCTEConsumer consumer = constructConsumerAndReplaceGroupBy(ctx, producer, cloneAgg, outputExpressions,
+                producerToConsumerSlotMap, replacedGroupBy);
+        List<Expression> otherAggFuncAliases = otherAggFuncs.stream()
+                .map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList());
+        for (Expression otherAggFuncAlias : otherAggFuncAliases) {
+            // otherAggFunc is instance of Alias
+            Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0));
+            outputExpressions.add(outputOtherFunc);
+            newToOriginDistinctFuncAlias.put(outputOtherFunc, (Alias) otherAggFuncAlias);
+        }
+        LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer);
+        newAggs.add(newAgg);
+    }
+
+    /**
+     * Creates a new consumer of {@code producer} and registers it with the cascades context. It fills
+     * {@code producerToConsumerSlotMap} with the inverse of the consumer's consumer-to-producer output map.
+     * It adds the group by keys of {@code cloneAgg}, remapped onto the consumer's slots, to
+     * {@code replacedGroupBy}, and also adds them to {@code outputExpressions} as slots.
+     */
+    private static LogicalCTEConsumer constructConsumerAndReplaceGroupBy(DistinctSplitContext ctx,
+            LogicalCTEProducer<Plan> producer, LogicalAggregate<Plan> cloneAgg, List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> producerToConsumerSlotMap, List<Expression> replacedGroupBy) {
+        LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        ctx.cascadesContext.putCTEIdToConsumer(consumer);
+        for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) {
+            producerToConsumerSlotMap.put(entry.getValue(), entry.getKey());
+        }
+        replacedGroupBy.addAll(ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), producerToConsumerSlotMap));
+        // NOTE(review): this cast assumes group by keys are plain slots at this stage
+        outputExpressions.addAll(replacedGroupBy.stream().map(Slot.class::cast).collect(Collectors.toList()));
+        return consumer;
+    }
+
+    /**
+     * Returns true when some argument after the first is neither an OrderExpression nor slot-free. For
+     * example, {@code count(distinct a, b)} returns true, while {@code group_concat(distinct a, ',')}
+     * returns false.
+     */
+    private static boolean isDistinctMultiColumns(AggregateFunction func) {
+        if (func.arity() <= 1) {
+            return false;
+        }
+        for (int i = 1; i < func.arity(); ++i) {
+            // think about group_concat(distinct col_1, ',')
+            if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Decides whether {@code agg} should be split. It collects aliased distinct functions that implement
+     * SupportMultiDistinct into {@code aliases}, and every other aliased aggregate function into
+     * {@code otherAggFuncs}.
+     */
+    private static boolean needTransform(LogicalAggregate<Plan> agg, List<Alias> aliases, List<Alias> otherAggFuncs) {
+        // TODO with source repeat aggregate need to be supported in future
+        if (agg.getSourceRepeat().isPresent()) {
+            return false;
+        }
+        Set<Expression> distinctFunc = new HashSet<>();
+        boolean distinctMultiColumns = false;
+        for (NamedExpression namedExpression : agg.getOutputExpressions()) {
+            if (!(namedExpression instanceof Alias) || !(namedExpression.child(0) instanceof AggregateFunction)) {
+                continue;
+            }
+            AggregateFunction aggFunc = (AggregateFunction) namedExpression.child(0);
+            if (aggFunc instanceof SupportMultiDistinct && aggFunc.isDistinct()) {
+                aliases.add((Alias) namedExpression);
+                distinctFunc.add(aggFunc);
+                distinctMultiColumns = distinctMultiColumns || isDistinctMultiColumns(aggFunc);
+            } else {
+                otherAggFuncs.add((Alias) namedExpression);
+            }
+        }
+        // fewer than two different distinct functions: nothing to split
+        if (distinctFunc.size() <= 1) {
+            return false;
+        }
+        // When no distinct function has multiple columns and there are group by keys, skip the rewrite.
+        // e.g. sql1: select count(distinct a), count(distinct b) from t1 group by c;
+        // sql2: select count(distinct a) from t1 group by c;
+        // The physical plans of sql1 and sql2 are similar (both are 2-phase aggregates),
+        // so there is no need to do this rewrite.
+        if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) {
+            return false;
+        }
+        return true;
+    }
+
+    /**
+     * Builds the final project over the join. It outputs, in {@code joinOutput} iteration order, each new
+     * aggregate output under the original alias' ExprId and name, followed by every group by key read
+     * from {@code outputJoinGroupBys} under the original key's ExprId and name.
+     */
+    private static LogicalProject<Plan> constructProject(List<Expression> groupBy, Map<Alias, Alias> joinOutput,
+            List<Expression> outputJoinGroupBys, LogicalJoin<Plan, Plan> join) {
+        List<NamedExpression> projects = new ArrayList<>();
+        for (Map.Entry<Alias, Alias> entry : joinOutput.entrySet()) {
+            projects.add(new Alias(entry.getValue().getExprId(), entry.getKey().toSlot(), entry.getValue().getName()));
+        }
+        // outputJoinGroupBys.size() == agg.getGroupByExpressions().size()
+        for (int i = 0; i < groupBy.size(); ++i) {
+            Slot slot = (Slot) groupBy.get(i);
+            projects.add(new Alias(slot.getExprId(), outputJoinGroupBys.get(i), slot.getName()));
+        }
+        return new LogicalProject<>(projects, join);
+    }
+
+    /**
+     * Joins the per-function aggregates left-deep, in list order. With no group by keys each aggregate
+     * yields a single row, so a cross join is used. Otherwise every aggregate's output starts with its
+     * group by keys, and the aggregates are inner joined on them. The keys are compared with
+     * {@link NullSafeEqual} rather than {@link EqualTo}, because NULL is a valid group and plain equality
+     * would drop that group from the join result.
+     */
+    private static LogicalJoin<Plan, Plan> constructJoin(List<LogicalAggregate<Plan>> newAggs,
+            List<Expression> groupBy) {
+        LogicalJoin<Plan, Plan> join;
+        if (groupBy.isEmpty()) {
+            join = new LogicalJoin<>(JoinType.CROSS_JOIN, newAggs.get(0), newAggs.get(1), null);
+            for (int j = 2; j < newAggs.size(); ++j) {
+                join = new LogicalJoin<>(JoinType.CROSS_JOIN, join, newAggs.get(j), null);
+            }
+        } else {
+            int len = groupBy.size();
+            join = new LogicalJoin<>(JoinType.INNER_JOIN,
+                    buildGroupByJoinConditions(newAggs.get(0).getOutput(), newAggs.get(1).getOutput(), len),
+                    newAggs.get(0), newAggs.get(1), null);
+            for (int j = 2; j < newAggs.size(); ++j) {
+                // the left input's output begins with the group by keys of the first aggregate
+                join = new LogicalJoin<>(JoinType.INNER_JOIN,
+                        buildGroupByJoinConditions(join.left().getOutput(), newAggs.get(j).getOutput(), len),
+                        join, newAggs.get(j), null);
+            }
+        }
+        return join;
+    }
+
+    /** Pairs the first {@code len} slots of both sides with null-safe equality ({@code <=>}). */
+    private static List<Expression> buildGroupByJoinConditions(List<Slot> leftSlots, List<Slot> rightSlots,
+            int len) {
+        List<Expression> conditions = new ArrayList<>(len);
+        for (int i = 0; i < len; ++i) {
+            conditions.add(new NullSafeEqual(leftSlots.get(i), rightSlots.get(i)));
+        }
+        return conditions;
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java
index 21e6ee1cba6..ba16b07ed5f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java
@@ -37,7 +37,7 @@ import java.util.List;
 
 /** count agg function. */
 public class Count extends NotNullableAggregateFunction
-        implements ExplicitlyCastableSignature, SupportWindowAnalytic, 
RollUpTrait {
+        implements ExplicitlyCastableSignature, SupportWindowAnalytic, 
RollUpTrait, SupportMultiDistinct {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             // count(*)
@@ -162,4 +162,10 @@ public class Count extends NotNullableAggregateFunction
     public Expression resultForEmptyInput() {
         return new BigIntLiteral(0);
     }
+
+    @Override
+    public AggregateFunction convertToMultiDistinct() {
+        return new MultiDistinctCount(getArgument(0),
+                getArguments().subList(1, arity()).toArray(new Expression[0]));
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java
index 2505329b2fe..61cd525e651 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java
@@ -37,7 +37,7 @@ import java.util.List;
  * AggregateFunction 'group_concat'. This class is generated by 
GenerateFunction.
  */
 public class GroupConcat extends NullableAggregateFunction
-        implements ExplicitlyCastableSignature {
+        implements ExplicitlyCastableSignature, SupportMultiDistinct {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT),
@@ -133,6 +133,7 @@ public class GroupConcat extends NullableAggregateFunction
         return SIGNATURES;
     }
 
+    @Override
     public MultiDistinctGroupConcat convertToMultiDistinct() {
         Preconditions.checkArgument(distinct,
                 "can't convert to multi_distinct_group_concat because there is 
no distinct args");
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
index b5616dad15c..d1c862f5de4 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
@@ -49,7 +49,7 @@ import java.util.List;
  */
 public class Sum extends NullableAggregateFunction
         implements UnaryExpression, ExplicitlyCastableSignature, 
ComputePrecisionForSum, SupportWindowAnalytic,
-        RollUpTrait {
+        RollUpTrait, SupportMultiDistinct {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
@@ -81,6 +81,7 @@ public class Sum extends NullableAggregateFunction
         super("sum", distinct, alwaysNullable, arg);
     }
 
+    @Override
     public MultiDistinctSum convertToMultiDistinct() {
         Preconditions.checkArgument(distinct,
                 "can't convert to multi_distinct_sum because there is no 
distinct args");
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
index 7c3873de01f..e02139420d2 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
@@ -56,7 +56,7 @@ import java.util.List;
  */
 public class Sum0 extends NotNullableAggregateFunction
         implements UnaryExpression, ExplicitlyCastableSignature, 
ComputePrecisionForSum,
-        SupportWindowAnalytic, RollUpTrait {
+        SupportWindowAnalytic, RollUpTrait, SupportMultiDistinct {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
@@ -84,6 +84,7 @@ public class Sum0 extends NotNullableAggregateFunction
         super("sum0", distinct, arg);
     }
 
+    @Override
     public MultiDistinctSum0 convertToMultiDistinct() {
         Preconditions.checkArgument(distinct,
                 "can't convert to multi_distinct_sum because there is no 
distinct args");
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java
new file mode 100644
index 00000000000..9feaf2025c4
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SupportMultiDistinct.java
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+/**
+ * Implemented by aggregate functions that have a corresponding MultiDistinctXXX function,
+ * e.g. SUM, SUM0, COUNT, GROUP_CONCAT.
+ */
+public interface SupportMultiDistinct {
+    /**
+     * Converts this DISTINCT aggregate call into its MultiDistinctXXX counterpart.
+     *
+     * <p>Implementations may narrow the return type (e.g. Sum returns MultiDistinctSum, Sum0 returns
+     * MultiDistinctSum0). NOTE(review): the Sum/Sum0 implementations reject non-DISTINCT calls via
+     * Preconditions.checkArgument, so callers should presumably only invoke this on DISTINCT calls;
+     * confirm the same holds for Count and GroupConcat.
+     *
+     * @return the equivalent multi-distinct aggregate function
+     */
+    AggregateFunction convertToMultiDistinct();
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java
new file mode 100644
index 00000000000..074135695a1
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.util.MatchingUtils;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.utframe.TestWithFeService;
+
+import org.junit.jupiter.api.Test;
+
+/**
+ * Plan-shape tests for the SplitMultiDistinct rewrite: queries with several DISTINCT aggregates are
+ * expected to be planned as separate aggregate branches under a CTE anchor/producer, joined back
+ * together beneath the result sink.
+ */
+public class SplitMultiDistinctTest extends TestWithFeService implements MemoPatternMatchSupported {
+    @Override
+    protected void runBeforeAll() throws Exception {
+        createDatabase("test");
+        createTable("create table test.test_distinct_multi(a int, b int, c int, d varchar(10), e date) "
+                + "distributed by hash(a) properties('replication_num'='1');");
+        connectContext.setDatabase("test");
+        // NOTE(review): presumably disabled so the empty test table is not pruned away before the
+        // distinct split is applied, which would change the asserted plan shapes -- confirm.
+        connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
+    }
+
+    @Test
+    void multiCountWithoutGby() {
+        checkTwoSingleColumnDistinctSplitWithoutGby(
+                "select count(distinct b), count(distinct a) from test_distinct_multi");
+    }
+
+    @Test
+    void multiSumWithoutGby() {
+        checkTwoSingleColumnDistinctSplitWithoutGby(
+                "select sum(distinct b), sum(distinct a) from test_distinct_multi");
+    }
+
+    @Test
+    void sumCountWithoutGby() {
+        checkTwoSingleColumnDistinctSplitWithoutGby(
+                "select sum(distinct b), count(distinct a) from test_distinct_multi");
+    }
+
+    @Test
+    void countMultiColumnsWithoutGby() {
+        String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi";
+        PlanChecker.from(connectContext).checkExplain(sql, planner -> {
+            Plan plan = planner.getOptimizedPlan();
+            MatchingUtils.assertMatches(plan,
+                    physicalCTEAnchor(
+                        physicalCTEProducer(any()),
+                        physicalResultSink(
+                            physicalProject(
+                                physicalNestedLoopJoin(
+                                    physicalHashAggregate(
+                                        physicalHashAggregate(
+                                            physicalDistribute(
+                                                physicalHashAggregate(any())))),
+                                    physicalDistribute(
+                                        physicalHashAggregate(
+                                            physicalHashAggregate(
+                                                physicalDistribute(
+                                                    physicalHashAggregate(any())))))
+                                )
+                            )
+                        )
+                    )
+            );
+        });
+    }
+
+    @Test
+    void countMultiColumnsWithGby() {
+        String sql = "select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d";
+        PlanChecker.from(connectContext).checkExplain(sql, planner -> {
+            Plan plan = planner.getOptimizedPlan();
+            // With GROUP BY the two branches are combined by a hash join (on the group key)
+            // instead of the nested-loop join used in the no-GROUP-BY cases.
+            MatchingUtils.assertMatches(plan,
+                    physicalCTEAnchor(
+                        physicalCTEProducer(
+                            any()),
+                        physicalResultSink(
+                            physicalDistribute(
+                                physicalProject(
+                                    physicalHashJoin(
+                                        physicalHashAggregate(
+                                            physicalHashAggregate(
+                                                physicalDistribute(
+                                                    physicalHashAggregate(any())))),
+                                        physicalHashAggregate(
+                                            physicalHashAggregate(
+                                                physicalDistribute(
+                                                    physicalHashAggregate(any()))))
+                                    )
+                                )
+                            )
+                        )
+                    )
+            );
+        });
+    }
+
+    /**
+     * Asserts the plan shape shared by queries with exactly two single-column DISTINCT aggregates and
+     * no GROUP BY: a CTE producer, and under the result sink a nested-loop join of two branches, each
+     * made of four hash-aggregate stages separated by distributes.
+     *
+     * @param sql query to plan; expected to select exactly two single-column DISTINCT aggregates
+     */
+    private void checkTwoSingleColumnDistinctSplitWithoutGby(String sql) {
+        PlanChecker.from(connectContext).checkExplain(sql, planner -> {
+            Plan plan = planner.getOptimizedPlan();
+            MatchingUtils.assertMatches(plan,
+                    physicalCTEAnchor(
+                        physicalCTEProducer(any()),
+                        physicalResultSink(
+                            physicalProject(
+                                physicalNestedLoopJoin(
+                                    physicalHashAggregate(
+                                        physicalDistribute(
+                                            physicalHashAggregate(
+                                                physicalHashAggregate(
+                                                    physicalDistribute(
+                                                        physicalHashAggregate(any())))))),
+                                    physicalDistribute(
+                                        physicalHashAggregate(
+                                            physicalDistribute(
+                                                physicalHashAggregate(
+                                                    physicalHashAggregate(
+                                                        physicalDistribute(
+                                                            physicalHashAggregate(any())))))))
+                                )
+                            )
+                        )
+                    )
+            );
+        });
+    }
+}
diff --git 
a/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out 
b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out
new file mode 100644
index 00000000000..2a1dd6fd9d6
Binary files /dev/null and 
b/regression-test/data/nereids_rules_p0/distinct_split/disitinct_split.out 
differ
diff --git 
a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy 
b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy
new file mode 100644
index 00000000000..02812b269a3
--- /dev/null
+++ 
b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy
@@ -0,0 +1,210 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("distinct_split") {
+    sql "set runtime_filter_mode = OFF"
+    sql "set disable_join_reorder=true"
+    sql "drop table if exists test_distinct_multi"
+    sql "create table test_distinct_multi(a int, b int, c int, d varchar(10), 
e date) distributed by hash(a) properties('replication_num'='1');"
+    sql "insert into test_distinct_multi 
values(1,2,3,'abc','2024-01-02'),(1,2,4,'abc','2024-01-03'),(2,2,4,'abcd','2024-01-02'),(1,2,3,'abcd','2024-01-04'),(1,2,4,'eee','2024-02-02'),(2,2,4,'abc','2024-01-02');"
+
+    // first bit 0 means distinct 1 col, 1 means distinct more than 1 col; 
second bit 0 means without group by, 1 means with group by;
+    // third bit 0 means there is 1 count(distinct) in projects, 1 means more 
than 1 count(distinct) in projects.
+
+    //000 distinct has 1 column, no group by, projection column has 1 count 
(distinct). four stages agg
+    qt_000_count """select count(distinct a) from test_distinct_multi"""
+
+    //001 distinct has 1 column, no group by, and multiple counts (distinct) 
in the projection column. The two-stage agg is slow for single point 
calculation in the second stage
+    qt_001_count """select count(distinct b), count(distinct a) from 
test_distinct_multi"""
+
+    //010 distinct has 1 column with group by, and the projection column has 1 
count (distinct). two-stage agg. The second stage follows group by hash
+    qt_010_count """select count(distinct a) from test_distinct_multi group by 
b order by 1"""
+    qt_010_count_same_column_with_groupby """select count(distinct a) from 
test_distinct_multi group by a order by 1"""
+
+    //011 distinct has one column with group by, and the projection column has 
multiple counts (distinct). two stages agg. The second stage follows group by 
hash
+    qt_011_count_same_column_with_groupby """select count(distinct 
a),count(distinct b)  from test_distinct_multi group by a  order by 1,2"""
+    qt_011_count_diff_column_with_groupby """select count(distinct 
a),count(distinct b)  from test_distinct_multi group by c order by 1,2"""
+    qt_011_count_diff_column_with_groupby_multi """select count(distinct 
a),count(distinct b)  from test_distinct_multi group by a,c order by 1,2"""
+    qt_011_count_diff_column_with_groupby_all """select count(distinct 
a),count(distinct b)  from test_distinct_multi group by a,b,c order by 1,2"""
+
+    //100 distinct columns with no group by, projection column with 1 count 
(distinct). Three stage agg, second stage gather
+    qt_100 """select count(distinct a,b) from test_distinct_multi"""
+
+    //101 distinct has multiple columns, no group by, and multiple counts 
(distinct) in the projection column (intercept). If the intercept is removed, 
it can be executed, but the result is incorrect
+    qt_101 """select count(distinct a,b), count(distinct a,c) from 
test_distinct_multi"""
+    qt_101_count_one_col_and_two_col """select count(distinct a,b), 
count(distinct c) from test_distinct_multi"""
+    qt_101_count_one_col_and_two_col """select count(distinct a,b), 
count(distinct a) from test_distinct_multi"""
+
+    //110 distinct has multiple columns, including group by, and the 
projection column has one count (distinct). three-stage agg. The second stage 
follows group by hash
+    qt_110_count_diff_column_with_groupby """select count(distinct a,b) from 
test_distinct_multi group by c  order by 1"""
+    qt_110_count_same_column_with_groupby1 """select count(distinct a,b) from 
test_distinct_multi group by a  order by 1"""
+    qt_110_count_same_column_with_groupby2 """select count(distinct a,b) from 
test_distinct_multi group by a,b  order by 1"""
+
+    //111 distinct has multiple columns, including group by, and the 
projection column has multiple counts (distinct) (intercept). If the intercept 
is removed, it can be executed, but the result is incorrect
+    qt_111_count_same_column_with_groupby1 """select count(distinct a,b), 
count(distinct a,c) from test_distinct_multi group by c  order by 1,2"""
+    qt_111_count_same_column_with_groupby2 """select count(distinct a,b), 
count(distinct c) from test_distinct_multi group by a,c order by 1,2"""
+    qt_111_count_diff_column_with_groupby """select count(distinct a,b), 
count(distinct a) from test_distinct_multi group by c order by 1,2"""
+
+    // testing other functions
+    qt_000_count_other_func """select count(distinct a), max(b),sum(c),min(a) 
from test_distinct_multi"""
+    qt_001_count_other_func """select count(distinct b), count(distinct a), 
max(b),sum(c),min(a)  from test_distinct_multi"""
+    qt_010_count_other_func """select count(distinct a), 
max(b),sum(c),min(a),b from test_distinct_multi group by b order by 1,2,3,4,5"""
+    qt_011_count_other_func """select count(distinct a), count(distinct 
b),max(b),sum(c),min(a),a  from test_distinct_multi group by a  order by 
1,2,3,4,5,6"""
+    qt_100_count_other_func """select count(distinct a,b), count(distinct 
b),max(b),sum(c),min(a) from test_distinct_multi"""
+    qt_101_count_other_func """select count(distinct a,b), count(distinct 
a,c),max(b),sum(c),min(a) from test_distinct_multi"""
+    qt_110_count_other_func """select count(distinct 
a,b),max(b),sum(c),min(a),c from test_distinct_multi group by c  order by 
1,2,3,4,5"""
+    qt_111_count_other_func """select count(distinct a,b), count(distinct 
a,c),max(b),sum(c),min(a),c  from test_distinct_multi group by c  order by 
1,2,3,4,5,6"""
+
+    // multi distinct three four five
+    qt_001_three """select count(distinct b), count(distinct a), 
count(distinct c) from test_distinct_multi"""
+    qt_001_four  """select count(distinct b), count(distinct a), 
count(distinct c), count(distinct d) from test_distinct_multi"""
+    qt_001_five  """select count(distinct b), count(distinct a), 
count(distinct c), count(distinct d), count(distinct e) from 
test_distinct_multi"""
+
+    qt_011_three """select count(distinct b), count(distinct a), 
count(distinct c) from test_distinct_multi group by d order by 1,2,3"""
+    qt_011_four """select count(distinct b), count(distinct a), count(distinct 
c), count(distinct d) from test_distinct_multi group by d order by 1,2,3,4"""
+    qt_011_five """select count(distinct b), count(distinct a), count(distinct 
c), count(distinct d), count(distinct e) from test_distinct_multi group by d 
order by 1,2,3,4,5"""
+    qt_011_three_gby_multi """select count(distinct b), count(distinct a), 
count(distinct c) from test_distinct_multi group by d,a order by 1,2,3"""
+    qt_011_four_gby_multi """select count(distinct b), count(distinct a), 
count(distinct c), count(distinct d) from test_distinct_multi group by d,c,a 
order by 1,2,3,4"""
+    qt_011_five_gby_multi """select count(distinct b), count(distinct a), 
count(distinct c), count(distinct d), count(distinct e) from 
test_distinct_multi group by d,b,a order by 1,2,3,4,5"""
+
+    qt_101_three """select count(distinct a,b), count(distinct a,c) , 
count(distinct a) from test_distinct_multi"""
+    qt_101_four """select count(distinct a,b), count(distinct a,c) , 
count(distinct a), count(distinct c) from test_distinct_multi"""
+    qt_101_five """select count(distinct a,b), count(distinct a,c) , 
count(distinct a,d), count(distinct c) , count(distinct a,b,c,d) from 
test_distinct_multi"""
+
+    qt_111_three """select count(distinct a,b), count(distinct a,c) , 
count(distinct a) from test_distinct_multi group by c  order by 1,2,3"""
+    qt_111_four """select count(distinct a,b), count(distinct a,c) , 
count(distinct a), count(distinct c) from test_distinct_multi group by e order 
by 1,2,3,4"""
+    qt_111_five """select count(distinct a,b), count(distinct a,c) , 
count(distinct a,d), count(distinct c) , count(distinct a,b,c,d) from 
test_distinct_multi group by e order by 1,2,3,4,5"""
+    qt_111_three_gby_multi """select count(distinct a,b), count(distinct a,c) 
, count(distinct a) from test_distinct_multi group by c,a  order by 1,2,3"""
+    qt_111_four_gby_multi """select count(distinct a,b), count(distinct a,c) , 
count(distinct a), count(distinct c) from test_distinct_multi group by e,a,b 
order by 1,2,3,4"""
+    qt_111_five_gby_multi """select count(distinct a,b), count(distinct a,c) , 
count(distinct a,d), count(distinct c) , count(distinct a,b,c,d) from 
test_distinct_multi group by e,a,b,c,d order by 1,2,3,4,5"""
+
+    // sum has two dimensions: 1. Is there one or more projection columns (0 
for one, 1 for more) 2. Is there a group by (0 for none, 1 for yes)
+    qt_00_sum """select sum(distinct b) from test_distinct_multi"""
+    qt_10_sum """select sum(distinct b), sum(distinct a) from 
test_distinct_multi"""
+    qt_01_sum """select sum(distinct b) from test_distinct_multi group by a 
order by 1"""
+    qt_11_sum """select sum(distinct b), sum(distinct a) from 
test_distinct_multi group by a order by 1,2"""
+
+    // avg has two dimensions: 1. Is there one or more projection columns (0 
for one, 1 for more) 2. Is there a group by (0 for no, 1 for yes)
+    qt_00_avg """select avg(distinct b) from test_distinct_multi"""
+    qt_10_avg """select avg(distinct b), avg(distinct a) from 
test_distinct_multi"""
+    qt_01_avg """select avg(distinct b) from test_distinct_multi group by a 
order by 1"""
+    qt_11_avg """select avg(distinct b), avg(distinct a) from 
test_distinct_multi group by a order by 1,2"""
+
+    //group_concat
+    sql """select group_concat(distinct d order by d)  from 
test_distinct_multi"""
+    sql """select group_concat(distinct d order by d), group_concat(distinct 
cast(a as string) order by cast(a as string))  from test_distinct_multi"""
+    sql """select group_concat(distinct d order by d)  from 
test_distinct_multi group by a order by 1"""
+    sql """select group_concat(distinct d order by d), group_concat(distinct 
cast(a as string) order by cast(a as string))  from test_distinct_multi  group 
by a order by 1,2"""
+
+    // mixed distinct function
+    qt_count_sum_avg_no_gby "select sum(distinct b), count(distinct a), 
avg(distinct c) from test_distinct_multi"
+    qt_count_multi_sum_avg_no_gby "select sum(distinct b), count(distinct 
a,d), avg(distinct c) from test_distinct_multi"
+    qt_count_sum_avg_with_gby "select sum(distinct b), count(distinct a), 
avg(distinct c) from test_distinct_multi group by b,a order by 1,2,3"
+    qt_count_multi_sum_avg_with_gby "select sum(distinct b), count(distinct 
a,d), avg(distinct c) from test_distinct_multi  group by a,b order by 1,2,3"
+
+    // There is a reference query in the upper layer
+    qt_multi_sum_has_upper """select c1+ c2 from (select sum(distinct b) c1, 
sum(distinct a) c2 from test_distinct_multi) t"""
+    qt_000_count_has_upper """select abs(c1) from (select count(distinct a) c1 
from test_distinct_multi) t"""
+    qt_010_count_has_upper """select c1+100 from (select count(distinct a) c1 
from test_distinct_multi group by b) t order by 1"""
+    qt_011_count_diff_column_with_groupby_all_has_upper """select max(c2), 
max(c1) from (select count(distinct a) c1,count(distinct b) c2 from 
test_distinct_multi group by a,b,c) t"""
+    qt_100_has_upper """select c1+1 from (select count(distinct a,b) c1 from 
test_distinct_multi) t where c1>0"""
+    qt_101_has_upper """select c1+c2+100 from (select count(distinct a,b) c1, 
count(distinct a,c) c2 from test_distinct_multi) t"""
+    qt_111_count_same_column_with_groupby1_has_upper """select max(c1), 
min(c2) from (select count(distinct a,b) c1, count(distinct a,c) c2 from 
test_distinct_multi group by c) t"""
+    qt_010_count_sum_other_func_has_upper """select sum(c0),max(c1+c2), 
min(c2+c3+c4),max(b)  from (select sum(distinct b) c0,count(distinct a) c1, 
max(b) c2,sum(c) c3,min(a) c4,b from test_distinct_multi group by b) t"""
+    qt_010_count_other_func_has_upper"""select sum(c0), max(c1+c2), 
min(c2+c3+c4),max(b)  from (select count(distinct b) c0,count(distinct a) c1, 
max(b) c2,sum(c) c3,min(a) c4,b from test_distinct_multi group by b) t"""
+
+    // In cte or in nested cte.
+    qt_cte_producer """with t1 as (select a,b from test_distinct_multi)
+    select count(distinct t.a), count(distinct tt.b) from t1 t cross join t1 
tt;"""
+    qt_cte_consumer """with t1 as (select count(distinct a), count(distinct b) 
from test_distinct_multi)
+    select * from t1 t cross join t1 tt;"""
+    qt_cte_multi_producer """
+    with t1 as (select count(distinct a), count(distinct b) from 
test_distinct_multi),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3;
+    """
+    qt_multi_cte_nest """
+    with t1 as (select count(distinct a), count(distinct b) from 
test_distinct_multi),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3, (with t1 as (select count(distinct a), 
count(distinct b) from test_distinct_multi),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3) tmp;
+    """
+    qt_multi_cte_nest2 """
+    with t1 as (with t1 as (select count(distinct a), count(distinct b) from 
test_distinct_multi),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3, (with t1 as (select count(distinct a), 
count(distinct b) from test_distinct_multi),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3) tmp)
+    select * from t1,t1,(with t1 as (select count(distinct a), count(distinct 
b) from test_distinct_multi),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3, (with t1 as (select count(distinct a), 
count(distinct b) from test_distinct_multi),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3) tmp) t
+    """
+    qt_cte_consumer_count_multi_column_with_group_by """with t1 as (select 
count(distinct a,b), count(distinct b,c) from test_distinct_multi group by d)
+    select * from t1 t cross join t1 tt order by 1,2,3,4;"""
+    qt_cte_consumer_count_multi_column_without_group_by """with t1 as (select 
sum(distinct a), count(distinct b,c) from test_distinct_multi)
+    select * from t1 t cross join t1 tt;"""
+    qt_cte_multi_producer_multi_column """
+    with t1 as (select count(distinct a), count(distinct b,d) from 
test_distinct_multi group by c),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi 
group by c),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3 order by 1,2,3,4,5,6;
+    """
+    qt_cte_multi_nested """
+    with tmp as (with t1 as (select count(distinct a), count(distinct b,d) 
from test_distinct_multi group by c),
+    t2 as (select sum(distinct a), sum(distinct b) from test_distinct_multi 
group by c),
+    t3 as (select sum(distinct a), count(distinct b) from test_distinct_multi)
+    select * from t1,t2,t3)
+    select * from tmp, (select sum(distinct a), count(distinct b,c) from 
test_distinct_multi) t, (select sum(distinct a), count(distinct b,c) from 
test_distinct_multi group by d) tt order by 1,2,3,4,5,6,7,8,9,10
+    """
+
+    // multi aggregate
+    qt_2_agg_count_distinct """select count(distinct c1) c3, count(distinct 
c2) c4 from (select count(distinct a,b) c1, count(distinct a,c) c2 from 
test_distinct_multi group by c) t"""
+    qt_3_agg_count_distinct """select count(distinct c3), count(distinct c4) 
from (select count(distinct c1) c3, count(distinct c2) c4 from (select 
count(distinct a,b) c1, count(distinct a,c) c2 from test_distinct_multi group 
by c) t) tt"""
+
+    // shape
+    sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
+    qt_multi_count_without_gby """explain shape plan select count(distinct b), 
count(distinct a) from test_distinct_multi"""
+    qt_multi_sum_without_gby """explain shape plan select sum(distinct b), 
sum(distinct a) from test_distinct_multi"""
+    qt_sum_count_without_gby """explain shape plan select sum(distinct b), 
count(distinct a) from test_distinct_multi"""
+    qt_multi_count_mulitcols_without_gby """explain shape plan select 
count(distinct b,c), count(distinct a,b) from test_distinct_multi"""
+    qt_multi_count_mulitcols_with_gby """explain shape plan select 
count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d"""
+    qt_three_count_mulitcols_without_gby """explain shape plan select 
count(distinct b,c), count(distinct a,b), count(distinct a,b,c) from 
test_distinct_multi"""
+    qt_four_count_mulitcols_with_gby """explain shape plan select 
count(distinct b,c), count(distinct a,b),count(distinct b,c,d), count(distinct 
a,b,c) from test_distinct_multi group by d"""
+    qt_has_other_func "explain shape plan select count(distinct b), 
count(distinct a), max(b),sum(c),min(a)  from test_distinct_multi"
+    qt_2_agg """explain shape plan select max(c1), min(c2) from (select 
count(distinct a,b) c1, count(distinct a,c) c2 from test_distinct_multi group 
by c) t"""
+
+    // should not rewrite
+    qt_multi_count_with_gby """explain shape plan select count(distinct b), 
count(distinct a) from test_distinct_multi group by c"""
+    qt_multi_sum_with_gby """explain shape plan select sum(distinct b), 
sum(distinct a) from test_distinct_multi group by c"""
+    qt_sum_count_with_gby """explain shape plan select sum(distinct b), 
count(distinct a) from test_distinct_multi group by a"""
+    qt_has_grouping """explain shape plan select count(distinct b), 
count(distinct a) from test_distinct_multi group by grouping sets((a,b),(c));"""
+    test {
+        sql """select count(distinct a,b), count(distinct a) from 
test_distinct_multi
+        group by grouping sets((a,b),(c));"""
+        exception "The query contains multi count distinct or sum distinct, 
each can't have multi columns"
+    }
+}
\ No newline at end of file
diff --git 
a/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy
 
b/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy
index 356b96267a8..f5545bc41b2 100644
--- 
a/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy
@@ -360,7 +360,7 @@ suite("aggregate_without_roll_up") {
             "from orders " +
             "where O_ORDERDATE < '2023-12-30' and O_ORDERDATE > '2023-12-01'"
     order_qt_query3_0_before "${query3_0}"
-    async_mv_rewrite_success(db, mv3_0, query3_0, "mv3_0")
+    async_mv_rewrite_fail(db, mv3_0, query3_0, "mv3_0")
     order_qt_query3_0_after "${query3_0}"
     sql """ DROP MATERIALIZED VIEW IF EXISTS mv3_0"""
 
@@ -883,7 +883,7 @@ suite("aggregate_without_roll_up") {
             "on lineitem.L_ORDERKEY = orders.O_ORDERKEY " +
             "where orders.O_ORDERDATE < '2023-12-30' and orders.O_ORDERDATE > 
'2023-12-01' "
     order_qt_query20_0_before "${query20_0}"
-    async_mv_rewrite_success(db, mv20_0, query20_0, "mv20_0")
+    async_mv_rewrite_fail(db, mv20_0, query20_0, "mv20_0")
     order_qt_query20_0_after "${query20_0}"
     sql """ DROP MATERIALIZED VIEW IF EXISTS mv20_0"""
 
diff --git 
a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_1.groovy 
b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_1.groovy
index 3aed3b0f9e2..59cff69ee89 100644
--- a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_1.groovy
+++ b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_1.groovy
@@ -410,7 +410,7 @@ suite("partition_mv_rewrite_dimension_1") {
         count(*) 
         from orders_1
         """
-    mv_rewrite_success(agg_sql_1, agg_mv_name_1)
+    mv_rewrite_fail(agg_sql_1, agg_mv_name_1)
     compare_res(agg_sql_1 + " order by 1,2,3,4,5,6")
     sql """DROP MATERIALIZED VIEW IF EXISTS ${agg_mv_name_1};"""
 
diff --git 
a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_3.groovy 
b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_3.groovy
index c7ee359cdef..a50d77bf3cc 100644
--- a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_3.groovy
+++ b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_3.groovy
@@ -145,7 +145,7 @@ suite("partition_mv_rewrite_dimension_2_3") {
             count(*)
             from orders_2_3
             left join lineitem_2_3 on lineitem_2_3.l_orderkey = 
orders_2_3.o_orderkey"""
-    mv_rewrite_success(sql_stmt_1, mv_name_1)
+    mv_rewrite_fail(sql_stmt_1, mv_name_1)
     compare_res(sql_stmt_1 + " order by 1,2,3,4,5,6")
     sql """DROP MATERIALIZED VIEW IF EXISTS ${mv_name_1};"""
 
diff --git 
a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_4.groovy 
b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_4.groovy
index e59b2771dd4..05c57974389 100644
--- a/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_4.groovy
+++ b/regression-test/suites/nereids_rules_p0/mv/dimension/dimension_2_4.groovy
@@ -577,7 +577,7 @@ suite("partition_mv_rewrite_dimension_2_4") {
             count(distinct case when O_SHIPPRIORITY > 2 and o_orderkey IN (2) 
then o_custkey else null end) as cnt_2 
             from orders_2_4  
             where o_orderkey > (-3) + 5 """
-    mv_rewrite_success(sql_stmt_13, mv_name_13)
+    mv_rewrite_fail(sql_stmt_13, mv_name_13)
     compare_res(sql_stmt_13 + " order by 1")
     sql """DROP MATERIALIZED VIEW IF EXISTS ${mv_name_13};"""
 
diff --git 
a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy 
b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy
index 1b546db0ff8..aeb39fb275a 100644
--- a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy
+++ b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy
@@ -149,12 +149,6 @@ suite("aggregate_strategies") {
             from $tableName
         )a
         group by c"""
-
-
-        test {
-            sql "select count(distinct id, name), count(distinct id) from 
$tableName"
-            exception "The query contains multi count distinct or sum 
distinct, each can't have multi columns"
-        }
     }
 
     test_aggregate_strategies('test_bucket1_table', 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to