This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 593a916ae6 [feature](nereids) split AggregateDisassemble into two 
rules (#14611)
593a916ae6 is described below

commit 593a916ae64e0ffd8884651ce4ec8279dd755658
Author: yinzhijian <[email protected]>
AuthorDate: Wed Nov 30 14:02:42 2022 +0800

    [feature](nereids) split AggregateDisassemble into two rules (#14611)
    
    # Proposed changes
    
    Issue Number: close #14280
    
    ## Problem summary
    
    The AggregateDisassemble rule is refactored and split into two rules, which 
are not dependent on each other.
    1. AggregateDisassemble splits the agg into two phases: Local, Global.
    1.1. For count function, the implementation is as 
follows: distinct_multi_count(update) + distinct_multi_count(merge)
    
    2. DistinctAggregateDisassemble splits the agg into 4 stages: Local, 
Global, Distinct Local, Distinct Global.
    2.1. For count function, the implementation is as 
follows: distinct_multi_count(update) + distinct_multi_count(merge) + sum(update) + 
sum(merge)
---
 .../glue/translator/ExpressionTranslator.java      |  12 +-
 .../glue/translator/PhysicalPlanTranslator.java    |  44 ++---
 .../properties/ChildOutputPropertyDeriver.java     |   3 +-
 .../nereids/properties/RequestPropertyDeriver.java |  32 +++-
 .../org/apache/doris/nereids/rules/RuleSet.java    |   2 +
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../doris/nereids/rules/analysis/BindFunction.java |   3 +-
 .../rules/rewrite/AggregateDisassemble.java        | 110 ++----------
 .../rewrite/DistinctAggregateDisassemble.java      | 190 +++++++++++++++++++++
 .../functions/agg/AggregateFunction.java           |  10 +-
 .../expressions/functions/agg/AggregateParam.java  |  43 +++--
 .../trees/plans/logical/LogicalAggregate.java      |  33 ++++
 .../rewrite/logical/AggregateDisassembleTest.java  | 182 +++++++++++---------
 .../trees/expressions/ExpressionEqualsTest.java    |   6 +-
 14 files changed, 435 insertions(+), 236 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index cf9532253b..1863da7c7f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -68,6 +68,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import 
org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.AggPhase;
 import org.apache.doris.nereids.types.coercion.AbstractDataType;
 import org.apache.doris.thrift.TFunctionBinaryType;
 
@@ -289,8 +290,17 @@ public class ExpressionTranslator extends 
DefaultExpressionVisitor<Expr, PlanTra
                 : NullableMode.ALWAYS_NOT_NULLABLE;
 
         boolean isAnalyticFunction = false;
+        String functionName = function.isDistinct() ? "MULTI_DISTINCT_" + 
function.getName() : function.getName();
+        if (function.getAggregateParam().aggPhase == AggPhase.DISTINCT_LOCAL
+                || function.getAggregateParam().aggPhase == 
AggPhase.DISTINCT_GLOBAL) {
+            if (function.getName().equalsIgnoreCase("count")) {
+                functionName = "SUM";
+            } else {
+                functionName = function.getName();
+            }
+        }
         org.apache.doris.catalog.AggregateFunction catalogFunction = new 
org.apache.doris.catalog.AggregateFunction(
-                new FunctionName(function.getName()), argTypes,
+                new FunctionName(functionName), argTypes,
                 function.getDataType().toCatalogDataType(),
                 function.getIntermediateTypes().toCatalogDataType(),
                 function.hasVarArguments(),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 235f0265b6..417b0b049e 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -159,7 +159,6 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
 
     /**
      * Translate Agg.
-     * todo: support DISTINCT
      */
     @Override
     public PlanFragment visitPhysicalAggregate(
@@ -201,44 +200,23 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         List<Expr> execPartitionExpressions = partitionExpressionList.stream()
                 .map(e -> ExpressionTranslator.translate(e, 
context)).collect(Collectors.toList());
         DataPartition mergePartition = DataPartition.UNPARTITIONED;
-        if (CollectionUtils.isNotEmpty(execPartitionExpressions)) {
+        if (CollectionUtils.isNotEmpty(execPartitionExpressions)
+                && aggregate.getAggPhase() != AggPhase.DISTINCT_GLOBAL) {
             mergePartition = 
DataPartition.hashPartitioned(execPartitionExpressions);
         }
 
         // 3. generate output tuple
         List<Slot> slotList = Lists.newArrayList();
         TupleDescriptor outputTupleDesc;
-        if (aggregate.getAggPhase() == AggPhase.LOCAL
-                || (aggregate.getAggPhase() == AggPhase.GLOBAL && 
aggregate.isFinalPhase())
-                || aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
-            slotList.addAll(groupSlotList);
-            slotList.addAll(aggFunctionOutput);
-            outputTupleDesc = generateTupleDesc(slotList, null, context);
-        } else {
-            // In the distinct agg scenario, global shares local's desc
-            AggregationNode localAggNode;
-            if (inputPlanFragment.getPlanRoot() instanceof ExchangeNode) {
-                localAggNode = (AggregationNode) 
inputPlanFragment.getPlanRoot().getChild(0);
-            } else {
-                // If the group by expr hits the partition key, there may be 
no exchange node
-                localAggNode = (AggregationNode) 
inputPlanFragment.getPlanRoot();
-            }
-            outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc();
-        }
+        slotList.addAll(groupSlotList);
+        slotList.addAll(aggFunctionOutput);
+        outputTupleDesc = generateTupleDesc(slotList, null, context);
 
-        // TODO: move setMergeForNereids to 
ExpressionTranslator.visitAggregateFunction
-        if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
-            for (FunctionCallExpr execAggregateFunction : 
execAggregateFunctions) {
-                if (!execAggregateFunction.isDistinct()) {
-                    execAggregateFunction.setMergeForNereids(true);
-                }
-            }
-        }
         AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions, 
execAggregateFunctions, outputTupleDesc,
                 outputTupleDesc, aggregate.getAggPhase().toExec());
         AggregationNode aggregationNode = new 
AggregationNode(context.nextPlanNodeId(),
                 inputPlanFragment.getPlanRoot(), aggInfo);
-        if (!aggregate.isFinalPhase()) {
+        if (!aggregate.getAggPhase().isGlobal() && !aggregate.isFinalPhase()) {
             aggregationNode.unsetNeedsFinalize();
         }
         PlanFragment currentFragment = inputPlanFragment;
@@ -247,8 +225,12 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 
aggregationNode.setUseStreamingPreagg(aggregate.isUsingStream());
                 aggregationNode.setIntermediateTuple();
                 break;
-            case GLOBAL:
             case DISTINCT_LOCAL:
+            case GLOBAL:
+            case DISTINCT_GLOBAL:
+                if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
+                    aggregationNode.setIntermediateTuple();
+                }
                 if (currentFragment.getPlanRoot() instanceof ExchangeNode) {
                     ExchangeNode exchangeNode = (ExchangeNode) 
currentFragment.getPlanRoot();
                     currentFragment = new 
PlanFragment(context.nextFragmentId(), exchangeNode, mergePartition);
@@ -257,7 +239,9 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                     inputPlanFragment.setDestination(exchangeNode);
                     context.addPlanFragment(currentFragment);
                 }
-                currentFragment.updateDataPartition(mergePartition);
+                if (aggregate.getAggPhase() != AggPhase.DISTINCT_LOCAL) {
+                    currentFragment.updateDataPartition(mergePartition);
+                }
                 break;
             default:
                 throw new RuntimeException("Unsupported yet");
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
index b4772dad73..86cf47386e 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
@@ -79,18 +79,17 @@ public class ChildOutputPropertyDeriver extends 
PlanVisitor<PhysicalProperties,
     public PhysicalProperties visitPhysicalAggregate(PhysicalAggregate<? 
extends Plan> agg, PlanContext context) {
         Preconditions.checkState(childrenOutputProperties.size() == 1);
         PhysicalProperties childOutputProperty = 
childrenOutputProperties.get(0);
-        // TODO: add distinct phase output properties
         switch (agg.getAggPhase()) {
             case LOCAL:
             case GLOBAL:
             case DISTINCT_LOCAL:
+            case DISTINCT_GLOBAL:
                 DistributionSpec childSpec = 
childOutputProperty.getDistributionSpec();
                 if (childSpec instanceof DistributionSpecHash) {
                     DistributionSpecHash distributionSpecHash = 
(DistributionSpecHash) childSpec;
                     return new 
PhysicalProperties(distributionSpecHash.withShuffleType(ShuffleType.BUCKETED));
                 }
                 return new 
PhysicalProperties(childOutputProperty.getDistributionSpec());
-            case DISTINCT_GLOBAL:
             default:
                 throw new RuntimeException("Could not derive output properties 
for agg phase: " + agg.getAggPhase());
         }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
index 96379b6339..8b0fc1c95d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
@@ -24,7 +24,9 @@ import org.apache.doris.nereids.memo.GroupExpression;
 import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.plans.AggPhase;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
@@ -36,9 +38,11 @@ import 
org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
 import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
 import org.apache.doris.nereids.util.JoinUtils;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 
 import java.util.List;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
@@ -94,10 +98,18 @@ public class RequestPropertyDeriver extends 
PlanVisitor<Void, PlanContext> {
         }
         // 2. second phase agg, need to return shuffle with partition key
         List<Expression> partitionExpressions = agg.getPartitionExpressions();
-        if (partitionExpressions.isEmpty()) {
+        if (partitionExpressions.isEmpty() && agg.getAggPhase() != 
AggPhase.DISTINCT_LOCAL) {
             addRequestPropertyToChildren(PhysicalProperties.GATHER);
             return null;
         }
+        if (agg.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
+            // use slots in distinct agg as shuffle slots
+            List<ExprId> shuffleSlots = 
extractFromDistinctFunction(agg.getOutputExpressions());
+            Preconditions.checkState(!shuffleSlots.isEmpty());
+            addRequestPropertyToChildren(
+                    PhysicalProperties.createHash(new 
DistributionSpecHash(shuffleSlots, ShuffleType.AGGREGATE)));
+            return null;
+        }
         // TODO: when parent is a join node,
         //    use requestPropertyFromParent to keep column order as join to 
avoid shuffle again.
         if 
(partitionExpressions.stream().allMatch(SlotReference.class::isInstance)) {
@@ -170,5 +182,23 @@ public class RequestPropertyDeriver extends 
PlanVisitor<Void, PlanContext> {
     private void addRequestPropertyToChildren(PhysicalProperties... 
physicalProperties) {
         requestPropertyToChildren.add(Lists.newArrayList(physicalProperties));
     }
+
+    private List<ExprId> extractFromDistinctFunction(List<NamedExpression> 
outputExpression) {
+        List<ExprId> exprIds = Lists.newArrayList();
+        for (NamedExpression originOutputExpr : outputExpression) {
+            Set<AggregateFunction> aggregateFunctions
+                    = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (aggregateFunction.isDistinct()) {
+                    for (Expression expr : aggregateFunction.children()) {
+                        Preconditions.checkState(expr instanceof 
SlotReference, "normalize aggregate failed to"
+                                + " normalize aggregate function " + 
aggregateFunction.toSql());
+                        exprIds.add(((SlotReference) expr).getExprId());
+                    }
+                }
+            }
+        }
+        return exprIds;
+    }
 }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
index 87ed6b1bb1..d3212f9214 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
@@ -41,6 +41,7 @@ import 
org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickS
 import 
org.apache.doris.nereids.rules.implementation.LogicalTVFRelationToPhysicalTVFRelation;
 import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
 import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
+import org.apache.doris.nereids.rules.rewrite.DistinctAggregateDisassemble;
 import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuterJoin;
 import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
 import org.apache.doris.nereids.rules.rewrite.logical.MergeLimits;
@@ -73,6 +74,7 @@ public class RuleSet {
             .add(SemiJoinSemiJoinTranspose.INSTANCE)
             .add(SemiJoinSemiJoinTransposeProject.INSTANCE)
             .add(new AggregateDisassemble())
+            .add(new DistinctAggregateDisassemble())
             .add(new PushdownFilterThroughProject())
             .add(new MergeProjects())
             .build();
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 32bb60b8b3..ec1698a068 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -77,6 +77,7 @@ public enum RuleType {
     NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
     NORMALIZE_REPEAT(RuleTypeClass.REWRITE),
     AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
+    DISTINCT_AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
     COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
     ELIMINATE_UNNECESSARY_PROJECT(RuleTypeClass.REWRITE),
     LOGICAL_SUB_QUERY_ALIAS_TO_LOGICAL_PROJECT(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
index 78d20462da..ac0c1238e0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
@@ -37,6 +37,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import 
org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.plans.AggPhase;
 import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.RelationId;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@@ -177,7 +178,7 @@ public class BindFunction implements AnalysisRuleFactory {
                 }
                 if (arguments.size() == 1) {
                     AggregateParam aggregateParam = new AggregateParam(
-                            unboundFunction.isDistinct(), true, false);
+                            unboundFunction.isDistinct(), true, 
AggPhase.LOCAL, false);
                     return new Count(aggregateParam, 
unboundFunction.getArguments().get(0));
                 }
             }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index a09066c267..4d0b55ebe3 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -17,7 +17,6 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
-import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Alias;
@@ -37,7 +36,6 @@ import com.google.common.collect.Maps;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -53,84 +51,23 @@ import java.util.stream.Collectors;
  *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(b) #1, Alias(SUM(a) + 1) 
#2], groupByExpr: [b])
  *   +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as 
b], groupByExpr: [k + 1])
  *       +-- childPlan
- *
- * Distinct Agg With Group By Processing:
- * If we have a query: SELECT count(distinct v1 * v2) + 1 FROM t GROUP BY k + 1
- * the initial plan is:
- *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, 
Alias(COUNT(distinct v1 * v2) + 1) #2]
- *                            , groupByExpr: [k + 1])
- *   +-- childPlan
- * we should rewrite to:
- *   Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [Alias(b) #1, 
Alias(COUNT(distinct a) + 1) #2], groupByExpr: [b])
- *   +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a])
- *       +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as 
a], groupByExpr: [k + 1, a])
- *           +-- childPlan
  * </pre>
  *
  * TODO:
- *     1. use different class represent different phase aggregate
- *     2. if instance count is 1, shouldn't disassemble the agg plan
+ *     1. if instance count is 1, shouldn't disassemble the agg plan
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
     @Override
     public Rule build() {
         return logicalAggregate()
-                .whenNot(LogicalAggregate::isDisassembled)
-                .then(aggregate -> {
-                    // used in secondDisassemble to transform local 
expressions into global
-                    Map<Expression, Expression> globalOutputSMap = 
Maps.newHashMap();
-                    Pair<LogicalAggregate<LogicalAggregate<GroupPlan>>, 
Boolean> result
-                            = disassembleAggregateFunction(aggregate, 
globalOutputSMap);
-                    LogicalAggregate<LogicalAggregate<GroupPlan>> newPlan = 
result.first;
-                    boolean hasDistinct = result.second;
-                    if (!hasDistinct) {
-                        return newPlan;
-                    }
-                    return disassembleDistinct(newPlan, globalOutputSMap);
-                }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
+                .when(LogicalAggregate::isFinalPhase)
+                .when(LogicalAggregate::isLocal)
+                
.then(this::disassembleAggregateFunction).toRule(RuleType.AGGREGATE_DISASSEMBLE);
     }
 
-    // only support distinct function with group by
-    // TODO: support distinct function without group by. (add second global 
phase)
-    private LogicalAggregate<LogicalAggregate<LogicalAggregate<GroupPlan>>> 
disassembleDistinct(
-            LogicalAggregate<LogicalAggregate<GroupPlan>> aggregate,
-            Map<Expression, Expression> globalOutputSMap) {
-        LogicalAggregate<GroupPlan> localDistinct = aggregate.child();
-        // replace expression in globalOutputExprs and globalGroupByExprs
-        List<NamedExpression> globalOutputExprs = 
localDistinct.getOutputExpressions().stream()
-                .map(e -> ExpressionUtils.replace(e, globalOutputSMap))
-                .map(NamedExpression.class::cast)
-                .collect(Collectors.toList());
-
-        // generate new plan
-        LogicalAggregate<LogicalAggregate<GroupPlan>> globalDistinct = new 
LogicalAggregate<>(
-                localDistinct.getGroupByExpressions(),
-                globalOutputExprs,
-                Optional.of(aggregate.getGroupByExpressions()),
-                true,
-                aggregate.isNormalized(),
-                false,
-                AggPhase.GLOBAL,
-                aggregate.getSourceRepeat(),
-                localDistinct
-        );
-        return new LogicalAggregate<>(
-                aggregate.getGroupByExpressions(),
-                aggregate.getOutputExpressions(),
-                true,
-                aggregate.isNormalized(),
-                true,
-                AggPhase.DISTINCT_LOCAL,
-                aggregate.getSourceRepeat(),
-                globalDistinct
-        );
-    }
-
-    private Pair<LogicalAggregate<LogicalAggregate<GroupPlan>>, Boolean> 
disassembleAggregateFunction(
-            LogicalAggregate<GroupPlan> aggregate,
-            Map<Expression, Expression> globalOutputSMap) {
-        boolean hasDistinct = Boolean.FALSE;
+    private LogicalAggregate<LogicalAggregate<GroupPlan>> 
disassembleAggregateFunction(
+            LogicalAggregate<GroupPlan> aggregate) {
         List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
         List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
         Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
@@ -162,11 +99,8 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
             Preconditions.checkState(originGroupByExpr instanceof 
SlotReference,
                     "normalize aggregate failed to normalize group by 
expression " + originGroupByExpr.toSql());
             inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
-            globalOutputSMap.put(originGroupByExpr, originGroupByExpr);
             localOutputExprs.add((SlotReference) originGroupByExpr);
         }
-        List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList();
-        List<NamedExpression> distinctExprsForLocalOutput = 
Lists.newArrayList();
         for (NamedExpression originOutputExpr : originOutputExprs) {
             Set<AggregateFunction> aggregateFunctions
                     = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
@@ -174,39 +108,19 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
                 if (inputSubstitutionMap.containsKey(aggregateFunction)) {
                     continue;
                 }
-                if (aggregateFunction.isDistinct()) {
-                    hasDistinct = Boolean.TRUE;
-                    for (Expression expr : aggregateFunction.children()) {
-                        Preconditions.checkState(expr instanceof 
SlotReference, "normalize aggregate failed to"
-                                + " normalize aggregate function " + 
aggregateFunction.toSql());
-                        distinctExprsForLocalOutput.add((SlotReference) expr);
-                        if (!inputSubstitutionMap.containsKey(expr)) {
-                            inputSubstitutionMap.put(expr, expr);
-                            globalOutputSMap.put(expr, expr);
-                        }
-                        distinctExprsForLocalGroupBy.add(expr);
-                    }
-                    continue;
-                }
-
                 AggregateFunction localAggregateFunction = 
aggregateFunction.withAggregateParam(
                         aggregateFunction.getAggregateParam()
-                                .withDistinct(false)
-                                .withGlobalAndDisassembled(false, true)
+                                .withPhaseAndDisassembled(false, 
AggPhase.LOCAL, true)
                 );
                 NamedExpression localOutputExpr = new 
Alias(localAggregateFunction, aggregateFunction.toSql());
 
                 AggregateFunction substitutionValue = aggregateFunction
                         // save the origin input types to the global aggregate 
functions
                         
.withAggregateParam(aggregateFunction.getAggregateParam()
-                                .withDistinct(false)
-                                .withGlobalAndDisassembled(true, true))
+                                .withPhaseAndDisassembled(true, 
AggPhase.GLOBAL, true))
                         
.withChildren(Lists.newArrayList(localOutputExpr.toSlot()));
 
                 inputSubstitutionMap.put(aggregateFunction, substitutionValue);
-                // because we use local output exprs to generate global output 
in disassembleDistinct,
-                // so we must use localAggregateFunction as key
-                globalOutputSMap.put(localAggregateFunction, 
substitutionValue);
                 localOutputExprs.add(localOutputExpr);
             }
         }
@@ -218,10 +132,6 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
                 .collect(Collectors.toList());
         List<Expression> globalGroupByExprs = localGroupByExprs.stream()
                 .map(e -> ExpressionUtils.replace(e, 
inputSubstitutionMap)).collect(Collectors.toList());
-        // To avoid repeated substitution of distinct expressions,
-        // here the expressions are put into the local after the substitution 
is completed
-        localOutputExprs.addAll(distinctExprsForLocalOutput);
-        localGroupByExprs.addAll(distinctExprsForLocalGroupBy);
         // 4. generate new plan
         LogicalAggregate<GroupPlan> localAggregate = new LogicalAggregate<>(
                 localGroupByExprs,
@@ -233,7 +143,7 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
                 aggregate.getSourceRepeat(),
                 aggregate.child()
         );
-        return Pair.of(new LogicalAggregate<>(
+        return new LogicalAggregate<>(
                 globalGroupByExprs,
                 globalOutputExprs,
                 true,
@@ -242,6 +152,6 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
                 AggPhase.GLOBAL,
                 aggregate.getSourceRepeat(),
                 localAggregate
-        ), hasDistinct);
+        );
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java
new file mode 100644
index 0000000000..1d182cf346
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.plans.AggPhase;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Splits an aggregate with a distinct function into four phases for distributed execution.
+ * NOTICE: DISTINCT GLOBAL output expressions' ExprId should be the SAME as the ORIGIN 
output expressions' ExprId.
+ * <pre>
+ * If we have a query: SELECT COUNT(distinct v1 * v2) + 1 FROM t
+ * the initial plan is:
+ *   +-- Aggregate(phase: [LOCAL], outputExpr: [Alias(COUNT(distinct v1 * v2) 
+ 1) #2])
+ *       +-- childPlan
+ * we should rewrite to:
+ *   Aggregate(phase: [GLOBAL DISTINCT], outputExpr: [Alias(SUM(c) + 1) #2])
+ *   +-- Aggregate(phase: [LOCAL DISTINCT], outputExpr: [SUM(b) as c] )
+ *       +-- Aggregate(phase: [GLOBAL], outputExpr: [COUNT(distinct a) as b])
+ *           +-- Aggregate(phase: [LOCAL], outputExpr: [COUNT(distinct v1 * 
v2) as a])
+ *               +-- childPlan
+ * </pre>
+ */
+public class DistinctAggregateDisassemble extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                .when(LogicalAggregate::needDistinctDisassemble)
+                
.then(this::disassembleAggregateFunction).toRule(RuleType.DISTINCT_AGGREGATE_DISASSEMBLE);
+    }
+
+    private 
LogicalAggregate<LogicalAggregate<LogicalAggregate<LogicalAggregate<GroupPlan>>>>
+            disassembleAggregateFunction(
+            LogicalAggregate<GroupPlan> aggregate) {
+        // Double-check to prevent incorrect changes
+        Preconditions.checkArgument(aggregate.getAggPhase() == AggPhase.LOCAL);
+        Preconditions.checkArgument(aggregate.isFinalPhase());
+        List<Expression> groupByExpressions = 
aggregate.getGroupByExpressions();
+        if (groupByExpressions == null || groupByExpressions.isEmpty()) {
+            // If there are no group by expressions, in order to parallelize,
+            // we need to manually use the distinct function arguments as the group 
by expressions of the LOCAL and GLOBAL phases
+            groupByExpressions = new 
ArrayList<>(getDistinctFunctionParams(aggregate));
+        }
+        Pair<List<NamedExpression>, List<NamedExpression>> localAndGlobal =
+                disassemble(aggregate.getOutputExpressions(),
+                        groupByExpressions,
+                        AggPhase.LOCAL, AggPhase.GLOBAL);
+        Pair<List<NamedExpression>, List<NamedExpression>> 
globalAndDistinctLocal =
+                disassemble(localAndGlobal.second,
+                        groupByExpressions,
+                        AggPhase.GLOBAL, AggPhase.DISTINCT_LOCAL);
+        Pair<List<NamedExpression>, List<NamedExpression>> 
distinctLocalAndDistinctGlobal =
+                disassemble(globalAndDistinctLocal.second,
+                        aggregate.getGroupByExpressions(),
+                        AggPhase.DISTINCT_LOCAL, AggPhase.DISTINCT_GLOBAL);
+        // generate new plan
+        LogicalAggregate<GroupPlan> localAggregate = new LogicalAggregate<>(
+                groupByExpressions,
+                localAndGlobal.first,
+                true,
+                aggregate.isNormalized(),
+                false,
+                AggPhase.LOCAL,
+                aggregate.getSourceRepeat(),
+                aggregate.child()
+        );
+        LogicalAggregate<LogicalAggregate<GroupPlan>> globalAggregate = new 
LogicalAggregate<>(
+                groupByExpressions,
+                globalAndDistinctLocal.first,
+                true,
+                aggregate.isNormalized(),
+                false,
+                AggPhase.GLOBAL,
+                aggregate.getSourceRepeat(),
+                localAggregate
+        );
+        LogicalAggregate<LogicalAggregate<LogicalAggregate<GroupPlan>>> 
distinctLocalAggregate =
+                new LogicalAggregate<>(
+                        aggregate.getGroupByExpressions(),
+                        distinctLocalAndDistinctGlobal.first,
+                        true,
+                        aggregate.isNormalized(),
+                        false,
+                        AggPhase.DISTINCT_LOCAL,
+                        aggregate.getSourceRepeat(),
+                        globalAggregate
+                );
+        return new LogicalAggregate<>(
+                aggregate.getGroupByExpressions(),
+                distinctLocalAndDistinctGlobal.second,
+                true,
+                aggregate.isNormalized(),
+                true,
+                AggPhase.DISTINCT_GLOBAL,
+                aggregate.getSourceRepeat(),
+                distinctLocalAggregate
+        );
+    }
+
+    private Pair<List<NamedExpression>, List<NamedExpression>> disassemble(
+            List<NamedExpression> originOutputExprs,
+            List<Expression> childGroupByExprs,
+            AggPhase childPhase,
+            AggPhase parentPhase) {
+        Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
+
+        List<NamedExpression> childOutputExprs = Lists.newArrayList();
+        // The groupBy slots are placed at the beginning of the output, 
consistent with the legacy optimizer
+        childGroupByExprs.stream().forEach(expression -> 
childOutputExprs.add((SlotReference) expression));
+        for (NamedExpression originOutputExpr : originOutputExprs) {
+            Set<AggregateFunction> aggregateFunctions
+                    = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (inputSubstitutionMap.containsKey(aggregateFunction)) {
+                    continue;
+                }
+                AggregateFunction childAggregateFunction = 
aggregateFunction.withAggregateParam(
+                        aggregateFunction.getAggregateParam()
+                                .withPhaseAndDisassembled(false, childPhase, 
true)
+                );
+                NamedExpression childOutputExpr = new 
Alias(childAggregateFunction, aggregateFunction.toSql());
+                AggregateFunction substitutionValue = aggregateFunction
+                        // save the origin input types to the global aggregate 
functions
+                        
.withAggregateParam(aggregateFunction.getAggregateParam()
+                                .withPhaseAndDisassembled(true, parentPhase, 
true))
+                        
.withChildren(Lists.newArrayList(childOutputExpr.toSlot()));
+
+                inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+                childOutputExprs.add(childOutputExpr);
+            }
+        }
+
+        // replace the aggregate functions in originOutputExprs to build parentOutputExprs
+        List<NamedExpression> parentOutputExprs = originOutputExprs.stream()
+                .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+        return Pair.of(childOutputExprs, parentOutputExprs);
+    }
+
+    private List<Expression> 
getDistinctFunctionParams(LogicalAggregate<GroupPlan> agg) {
+        List<Expression> result = new ArrayList<>();
+        for (NamedExpression originOutputExpr : agg.getOutputExpressions()) {
+            Set<AggregateFunction> aggregateFunctions
+                    = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (aggregateFunction.isDistinct()) {
+                    result.addAll(aggregateFunction.children());
+                }
+            }
+        }
+        return result;
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
index 80c799391b..e97071f0a7 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
@@ -38,7 +38,7 @@ public abstract class AggregateFunction extends BoundFunction 
implements Expects
     private final AggregateParam aggregateParam;
 
     public AggregateFunction(String name, Expression... arguments) {
-        this(name, AggregateParam.global(), arguments);
+        this(name, AggregateParam.finalPhase(), arguments);
     }
 
     public AggregateFunction(String name, AggregateParam aggregateParam, 
Expression... arguments) {
@@ -79,7 +79,7 @@ public abstract class AggregateFunction extends BoundFunction 
implements Expects
 
     @Override
     public final DataType getDataType() {
-        if (aggregateParam.isGlobal) {
+        if (aggregateParam.aggPhase.isGlobal() || aggregateParam.isFinalPhase) 
{
             return getFinalType();
         } else {
             return getIntermediateTypes();
@@ -114,7 +114,11 @@ public abstract class AggregateFunction extends 
BoundFunction implements Expects
     }
 
     public boolean isGlobal() {
-        return aggregateParam.isGlobal;
+        return aggregateParam.aggPhase.isGlobal();
+    }
+
+    public boolean isFinalPhase() {
+        return aggregateParam.isFinalPhase;
     }
 
     public boolean isDisassembled() {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
index 9e4b8e1394..c96e437768 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
@@ -17,51 +17,57 @@
 
 package org.apache.doris.nereids.trees.expressions.functions.agg;
 
+import org.apache.doris.nereids.trees.plans.AggPhase;
+
 import com.google.common.base.Preconditions;
 
 import java.util.Objects;
 
 /** AggregateParam. */
 public class AggregateParam {
-    public final boolean isGlobal;
+    public final boolean isFinalPhase;
+
+    public final AggPhase aggPhase;
 
     public final boolean isDistinct;
 
     public final boolean isDisassembled;
 
     /** AggregateParam */
-    public AggregateParam(boolean isDistinct, boolean isGlobal, boolean 
isDisassembled) {
+    public AggregateParam(boolean isDistinct, boolean isFinalPhase, AggPhase 
aggPhase, boolean isDisassembled) {
+        this.isFinalPhase = isFinalPhase;
         this.isDistinct = isDistinct;
-        this.isGlobal = isGlobal;
+        this.aggPhase = aggPhase;
         this.isDisassembled = isDisassembled;
-        if (!isGlobal) {
-            Preconditions.checkArgument(isDisassembled == true,
-                    "local aggregate should be disassembed");
+        if (!isFinalPhase) {
+            Preconditions.checkArgument(isDisassembled,
+                    "non-final phase aggregate should be disassembed");
         }
     }
 
-    public static AggregateParam global() {
-        return new AggregateParam(false, true, false);
+    public static AggregateParam finalPhase() {
+        return new AggregateParam(false, true, AggPhase.LOCAL, false);
     }
 
-    public static AggregateParam distinctAndGlobal() {
-        return new AggregateParam(true, true, false);
+    public static AggregateParam distinctAndFinalPhase() {
+        return new AggregateParam(true, true, AggPhase.LOCAL, false);
     }
 
     public AggregateParam withDistinct(boolean isDistinct) {
-        return new AggregateParam(isDistinct, isGlobal, isDisassembled);
+        return new AggregateParam(isDistinct, isFinalPhase, aggPhase, 
isDisassembled);
     }
 
-    public AggregateParam withGlobal(boolean isGlobal) {
-        return new AggregateParam(isDistinct, isGlobal, isDisassembled);
+    public AggregateParam withAggPhase(AggPhase aggPhase) {
+        return new AggregateParam(isDistinct, isFinalPhase, aggPhase, 
isDisassembled);
     }
 
     public AggregateParam withDisassembled(boolean isDisassembled) {
-        return new AggregateParam(isDistinct, isGlobal, isDisassembled);
+        return new AggregateParam(isDistinct, isFinalPhase, aggPhase, 
isDisassembled);
     }
 
-    public AggregateParam withGlobalAndDisassembled(boolean isGlobal, boolean 
isDisassembled) {
-        return new AggregateParam(isDistinct, isGlobal, isDisassembled);
+    public AggregateParam withPhaseAndDisassembled(boolean isFinalPhase, 
AggPhase aggPhase,
+                                                      boolean isDisassembled) {
+        return new AggregateParam(isDistinct, isFinalPhase, aggPhase, 
isDisassembled);
     }
 
     @Override
@@ -74,12 +80,13 @@ public class AggregateParam {
         }
         AggregateParam that = (AggregateParam) o;
         return isDistinct == that.isDistinct
-                && Objects.equals(isGlobal, that.isGlobal)
+                && isFinalPhase == that.isFinalPhase
+                && Objects.equals(aggPhase, that.aggPhase)
                 && Objects.equals(isDisassembled, that.isDisassembled);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(isDistinct, isGlobal, isDisassembled);
+        return Objects.hash(isDistinct, isFinalPhase, aggPhase, 
isDisassembled);
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index db5d703c4f..e94d32dfba 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -22,6 +22,7 @@ import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.plans.AggPhase;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.PlanType;
@@ -35,6 +36,7 @@ import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 
 /**
  * Logical Aggregate plan.
@@ -204,10 +206,41 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
                 .build();
     }
 
+    public boolean isLocal() {
+        return aggPhase.isLocal();
+    }
+
     public boolean isDisassembled() {
         return disassembled;
     }
 
+    /**
+     * Check if distinct disassembling applies: a final-phase LOCAL aggregate with exactly one distinct function
+     * @return true means that distinct disassembling applies
+     */
+    public boolean needDistinctDisassemble() {
+        // It is sufficient to split an aggregate function with a groupBy 
expression into two stages(Local & Global),
+        // no need for four stages(Local & Global & Distinct Local & Distinct 
Global)
+        if (!isFinalPhase || aggPhase != AggPhase.LOCAL) {
+            return false;
+        }
+        int distinctFunctionCount = 0;
+        for (NamedExpression originOutputExpr : outputExpressions) {
+            Set<AggregateFunction> aggregateFunctions
+                    = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (aggregateFunction.isDistinct()) {
+                    distinctFunctionCount++;
+                    if (distinctFunctionCount > 1) {
+                        return false;
+                    }
+                }
+            }
+        }
+        // Only one distinct function is supported
+        return distinctFunctionCount == 1;
+    }
+
     public boolean isNormalized() {
         return normalized;
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index f10b6325a2..9d4d5711e2 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.rules.rewrite.logical;
 
 import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
+import org.apache.doris.nereids.rules.rewrite.DistinctAggregateDisassemble;
 import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -44,6 +45,7 @@ import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
 
+import java.util.ArrayList;
 import java.util.List;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -206,121 +208,147 @@ public class AggregateDisassembleTest implements 
PatternMatchSupported {
     /**
      * <pre>
      * the initial plan is:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as 
c], groupByExpr: [id])
+     *   Aggregate(phase: [LOCAL], outputExpr: [(COUNT(distinct age) + 2) as 
c], groupByExpr: [])
      *   +-- childPlan(id, name, age)
      * we should rewrite to:
-     *   Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct age) 
+ 2) as c], groupByExpr: [id])
-     *   +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr: 
[id, age])
-     *       +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr: 
[id, age])
-     *           +-- childPlan(id, name, age)
+     *   Aggregate(phase: [GLOBAL], outputExpr: [count(distinct c)], 
groupByExpr: [])
+     *   +-- Aggregate(phase: [LOCAL], outputExpr: [(COUNT(distinct age) + 2) 
as c], groupByExpr: [])
+     *       +-- childPlan(id, name, age)
      * </pre>
      */
     @Test
-    public void distinctAggregateWithGroupBy() {
-        List<Expression> groupExpressionList = 
Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
+    public void distinctAggregateWithoutGroupByApply2PhaseRule() {
+        List<Expression> groupExpressionList = new ArrayList<>();
         List<NamedExpression> outputExpressionList = Lists.newArrayList(new 
Alias(
-                new Add(new Count(AggregateParam.distinctAndGlobal(), 
rStudent.getOutput().get(2).toSlot()),
+                new Add(new Count(AggregateParam.distinctAndFinalPhase(), 
rStudent.getOutput().get(2).toSlot()),
                         new IntegerLiteral(2)), "c"));
         Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
 
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new AggregateDisassemble())
+                .matchesFromRoot(
+                    logicalAggregate(
+                            logicalAggregate()
+                                    .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                    .when(agg -> 
agg.getOutputExpressions().size() == 1)
+                                    .when(agg -> 
agg.getGroupByExpressions().isEmpty())
+                    ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
+                            .when(agg -> agg.getOutputExpressions().size() == 
1)
+                            .when(agg -> agg.getGroupByExpressions().isEmpty())
+                );
+    }
+
+    @Test
+    public void distinctWithNormalAggregateFunctionApply2PhaseRule() {
+        List<Expression> groupExpressionList = 
Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(
+                new Alias(new Count(AggregateParam.distinctAndFinalPhase(), 
rStudent.getOutput().get(2).toSlot()), "c"),
+                new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), 
"sum"));
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
+
         // check local:
         // id
         Expression localOutput0 = rStudent.getOutput().get(0);
-        // age
-        Expression localOutput1 = rStudent.getOutput().get(2);
+        // count
+        Count localOutput1 = new Count(new AggregateParam(true, false, 
AggPhase.LOCAL, true), rStudent.getOutput().get(2).toSlot());
+        // sum
+        Sum localOutput2 = new Sum(new AggregateParam(false, false, 
AggPhase.LOCAL, true), rStudent.getOutput().get(0).toSlot());
         // id
         Expression localGroupBy0 = rStudent.getOutput().get(0);
-        // age
-        Expression localGroupBy1 = rStudent.getOutput().get(2);
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), root)
                 .applyTopDown(new AggregateDisassemble())
                 .matchesFromRoot(
-                        logicalAggregate(
-                                logicalAggregate(
-                                        logicalAggregate()
-                                                .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
-                                                .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0))
-                                                .when(agg -> 
agg.getOutputExpressions().get(1).equals(localOutput1))
-                                                .when(agg -> 
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
-                                                .when(agg -> 
agg.getGroupByExpressions().get(1).equals(localGroupBy1))
-                                ).when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
-                                        .when(agg -> 
agg.getOutputExpressions().get(0)
-                                                
.equals(agg.child().getOutputExpressions().get(0)))
-                                        .when(agg -> 
agg.getOutputExpressions().get(1)
-                                                
.equals(agg.child().getOutputExpressions().get(1)))
-                                        .when(agg -> 
agg.getGroupByExpressions().get(0)
-                                                
.equals(agg.child().getOutputExpressions().get(0)))
-                                        .when(agg -> 
agg.getGroupByExpressions().get(1)
-                                                
.equals(agg.child().getOutputExpressions().get(1)))
-                        ).when(agg -> 
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
-                                .when(agg -> agg.getOutputExpressions().size() 
== 1)
-                                .when(agg -> agg.getOutputExpressions().get(0) 
instanceof Alias)
-                                .when(agg -> 
agg.getOutputExpressions().get(0).child(0) instanceof Add)
-                                .when(agg -> agg.getGroupByExpressions().get(0)
-                                        
.equals(agg.child().child().getOutputExpressions().get(0)))
-                                .when(agg -> 
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
-                                        0).getExprId())
+                    logicalAggregate(
+                            logicalAggregate()
+                                    .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                    .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0))
+                                    .when(agg -> 
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+                                    .when(agg -> 
agg.getOutputExpressions().get(2).child(0).equals(localOutput2))
+                                    .when(agg -> 
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
+                    ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
+                            .when(agg -> {
+                                Slot child = 
agg.child().getOutputExpressions().get(1).toSlot();
+                                
Assertions.assertTrue(agg.getOutputExpressions().get(0).child(0) instanceof 
Count);
+                                return 
agg.getOutputExpressions().get(0).child(0).child(0).equals(child);
+                            })
+                            .when(agg -> {
+                                Slot child = 
agg.child().getOutputExpressions().get(2).toSlot();
+                                
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof 
Sum);
+                                return ((Sum) 
agg.getOutputExpressions().get(1).child(0)).child().equals(child);
+                            })
+                            .when(agg -> agg.getGroupByExpressions().get(0)
+                                    
.equals(agg.child().getOutputExpressions().get(0)))
                 );
     }
 
     @Test
-    public void distinctWithNormalAggregateFunction() {
+    public void distinctWithNormalAggregateFunctionApply4PhaseRule() {
         List<Expression> groupExpressionList = 
Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
         List<NamedExpression> outputExpressionList = Lists.newArrayList(
-                new Alias(new Count(AggregateParam.distinctAndGlobal(), 
rStudent.getOutput().get(2).toSlot()), "c"),
+                new Alias(new Count(AggregateParam.distinctAndFinalPhase(), 
rStudent.getOutput().get(2).toSlot()), "c"),
                 new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), 
"sum"));
         Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
 
         // check local:
         // id
         Expression localOutput0 = rStudent.getOutput().get(0);
+        // count
+        Count localOutput1 = new Count(new AggregateParam(true, false, 
AggPhase.LOCAL, true), rStudent.getOutput().get(2).toSlot());
         // sum
-        Sum localOutput1 = new Sum(new AggregateParam(false, false, true), 
rStudent.getOutput().get(0).toSlot());
-        // age
-        Expression localOutput2 = rStudent.getOutput().get(2);
+        Sum localOutput2 = new Sum(new AggregateParam(false, false, 
AggPhase.LOCAL, true), rStudent.getOutput().get(0).toSlot());
         // id
         Expression localGroupBy0 = rStudent.getOutput().get(0);
-        // age
-        Expression localGroupBy1 = rStudent.getOutput().get(2);
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), root)
-                .applyTopDown(new AggregateDisassemble())
+                .applyTopDown(new DistinctAggregateDisassemble())
                 .matchesFromRoot(
+                logicalAggregate(
+                    logicalAggregate(
                         logicalAggregate(
-                                logicalAggregate(
-                                        logicalAggregate()
-                                                .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
-                                                .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0))
-                                                .when(agg -> 
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
-                                                .when(agg -> 
agg.getOutputExpressions().get(2).equals(localOutput2))
-                                                .when(agg -> 
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
-                                                .when(agg -> 
agg.getGroupByExpressions().get(1).equals(localGroupBy1))
-                                ).when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
-                                        .when(agg -> 
agg.getOutputExpressions().get(0)
-                                                
.equals(agg.child().getOutputExpressions().get(0)))
-                                        .when(agg -> {
-                                            Slot child = 
agg.child().getOutputExpressions().get(1).toSlot();
-                                            
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof 
Sum);
-                                            return ((Sum) 
agg.getOutputExpressions().get(1).child(0)).child().equals(child);
-                                        })
-                                        .when(agg -> 
agg.getOutputExpressions().get(2)
-                                                
.equals(agg.child().getOutputExpressions().get(2)))
-                                        .when(agg -> 
agg.getGroupByExpressions().get(0)
-                                                
.equals(agg.child().getOutputExpressions().get(0)))
-                                        .when(agg -> 
agg.getGroupByExpressions().get(1)
-                                                
.equals(agg.child().getOutputExpressions().get(2)))
-                        ).when(agg -> 
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
-                                .when(agg -> agg.getOutputExpressions().size() 
== 2)
-                                .when(agg -> agg.getOutputExpressions().get(0) 
instanceof Alias)
-                                .when(agg -> 
agg.getOutputExpressions().get(0).child(0) instanceof Count)
-                                .when(agg -> 
agg.getOutputExpressions().get(1).child(0) instanceof Sum)
-                                .when(agg -> 
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
-                                        0).getExprId())
-                                .when(agg -> 
agg.getOutputExpressions().get(1).getExprId() == outputExpressionList.get(
-                                        1).getExprId())
+                                logicalAggregate()
+                                        .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                        .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0))
+                                        .when(agg -> 
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+                                        .when(agg -> 
agg.getOutputExpressions().get(2).child(0).equals(localOutput2))
+                                        .when(agg -> 
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
+                        ).when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
+                                .when(agg -> {
+                                    Slot child = 
agg.child().getOutputExpressions().get(1).toSlot();
+                                    
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof 
Count);
+                                    return 
agg.getOutputExpressions().get(1).child(0).child(0).equals(child);
+                                })
+                                .when(agg -> {
+                                    Slot child = 
agg.child().getOutputExpressions().get(2).toSlot();
+                                    
Assertions.assertTrue(agg.getOutputExpressions().get(2).child(0) instanceof 
Sum);
+                                    return ((Sum) 
agg.getOutputExpressions().get(2).child(0)).child().equals(child);
+                                })
                                 .when(agg -> agg.getGroupByExpressions().get(0)
-                                        
.equals(agg.child().child().getOutputExpressions().get(0)))
+                                        
.equals(agg.child().getOutputExpressions().get(0)))
+                    ).when(agg -> 
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
+                            .when(agg -> {
+                                Slot child = 
agg.child().getOutputExpressions().get(1).toSlot();
+                                
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof 
Count);
+                                return 
agg.getOutputExpressions().get(1).child(0).child(0).equals(child);
+                            })
+                            .when(agg -> {
+                                Slot child = 
agg.child().getOutputExpressions().get(2).toSlot();
+                                
Assertions.assertTrue(agg.getOutputExpressions().get(2).child(0) instanceof 
Sum);
+                                return ((Sum) 
agg.getOutputExpressions().get(2).child(0)).child().equals(child);
+                            })
+                            .when(agg -> agg.getGroupByExpressions().get(0)
+                                    
.equals(agg.child().getOutputExpressions().get(0)))
+                ).when(agg -> 
agg.getAggPhase().equals(AggPhase.DISTINCT_GLOBAL))
+                        .when(agg -> agg.getOutputExpressions().size() == 2)
+                        .when(agg -> agg.getOutputExpressions().get(0) 
instanceof Alias)
+                        .when(agg -> 
agg.getOutputExpressions().get(0).child(0) instanceof Count)
+                        .when(agg -> 
agg.getOutputExpressions().get(1).child(0) instanceof Sum)
+                        .when(agg -> 
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
+                                0).getExprId())
+                        .when(agg -> 
agg.getOutputExpressions().get(1).getExprId() == outputExpressionList.get(
+                                1).getExprId())
+                        .when(agg -> agg.getGroupByExpressions().get(0)
+                                
.equals(agg.child().child().child().getOutputExpressions().get(0)))
                 );
     }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
index dd8ea662d0..c3bd3830cc 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
@@ -177,13 +177,13 @@ public class ExpressionEqualsTest {
         Assertions.assertEquals(count1, count2);
         Assertions.assertEquals(count1.hashCode(), count2.hashCode());
 
-        Count count3 = new Count(AggregateParam.distinctAndGlobal(), child1);
-        Count count4 = new Count(AggregateParam.distinctAndGlobal(), child2);
+        Count count3 = new Count(AggregateParam.distinctAndFinalPhase(), child1);
+        Count count4 = new Count(AggregateParam.distinctAndFinalPhase(), child2);
         Assertions.assertEquals(count3, count4);
         Assertions.assertEquals(count3.hashCode(), count4.hashCode());
 
         // bad case
-        Count count5 = new Count(AggregateParam.distinctAndGlobal(), child1);
+        Count count5 = new Count(AggregateParam.distinctAndFinalPhase(), child1);
         Count count6 = new Count(child2);
         Assertions.assertNotEquals(count5, count6);
         Assertions.assertNotEquals(count5.hashCode(), count6.hashCode());


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to