This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 90f11ed7c1 [enhancement](Nereids) remove unnecessary exchange between 
global and distinct local aggregate node (#13057)
90f11ed7c1 is described below

commit 90f11ed7c1116a48a7a10acb0a6bf50682fddd19
Author: morrySnow <[email protected]>
AuthorDate: Thu Sep 29 23:12:37 2022 +0800

    [enhancement](Nereids) remove unnecessary exchange between global and 
distinct local aggregate node (#13057)
    
    Add partition expressions to LogicalAggregate. When an aggregate that 
contains a distinct aggregate function is disassembled, the partition 
expressions are set to the aggregate's original group-by expression list.
---
 .../nereids/jobs/cascades/CostAndEnforcerJob.java  |  5 +-
 .../properties/ChildOutputPropertyDeriver.java     | 17 ++--
 .../LogicalAggToPhysicalHashAgg.java               |  4 +-
 .../rules/rewrite/AggregateDisassemble.java        | 75 +++++++----------
 .../rules/rewrite/logical/NormalizeAggregate.java  |  5 +-
 .../trees/plans/logical/LogicalAggregate.java      | 47 ++++++++---
 .../trees/plans/physical/PhysicalAggregate.java    | 33 ++++----
 .../properties/ChildOutputPropertyDeriverTest.java |  5 +-
 .../rewrite/logical/AggregateDisassembleTest.java  | 96 ++++------------------
 9 files changed, 114 insertions(+), 173 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
index a11161f069..188acaeba5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
@@ -212,10 +212,9 @@ public class CostAndEnforcerJob extends Job implements 
Cloneable {
 
     private void enforce(PhysicalProperties outputProperty, 
List<PhysicalProperties> requestChildrenProperty) {
         PhysicalProperties requiredProperties = 
context.getRequiredProperties();
-
-        EnforceMissingPropertiesHelper enforceMissingPropertiesHelper
-                = new EnforceMissingPropertiesHelper(context, groupExpression, 
curTotalCost);
         if (!outputProperty.satisfy(requiredProperties)) {
+            EnforceMissingPropertiesHelper enforceMissingPropertiesHelper
+                    = new EnforceMissingPropertiesHelper(context, 
groupExpression, curTotalCost);
             PhysicalProperties addEnforcedProperty = 
enforceMissingPropertiesHelper
                     .enforceProperty(outputProperty, requiredProperties);
             curTotalCost = enforceMissingPropertiesHelper.getCurTotalCost();
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
index 43cedd78a4..e1ff5168ad 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
@@ -20,8 +20,6 @@ package org.apache.doris.nereids.properties;
 import org.apache.doris.nereids.PlanContext;
 import org.apache.doris.nereids.memo.GroupExpression;
 import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
-import org.apache.doris.nereids.trees.expressions.ExprId;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
@@ -42,7 +40,6 @@ import com.google.common.base.Preconditions;
 
 import java.util.List;
 import java.util.Objects;
-import java.util.stream.Collectors;
 
 /**
  * Used for property drive.
@@ -76,18 +73,14 @@ public class ChildOutputPropertyDeriver extends 
PlanVisitor<PhysicalProperties,
         // TODO: add distinct phase output properties
         switch (agg.getAggPhase()) {
             case LOCAL:
-                return new 
PhysicalProperties(childOutputProperty.getDistributionSpec());
             case GLOBAL:
             case DISTINCT_LOCAL:
-                List<ExprId> columns = agg.getPartitionExpressions().stream()
-                        .map(SlotReference.class::cast)
-                        .map(SlotReference::getExprId)
-                        .collect(Collectors.toList());
-                if (columns.isEmpty()) {
-                    return PhysicalProperties.GATHER;
+                DistributionSpec childSpec = 
childOutputProperty.getDistributionSpec();
+                if (childSpec instanceof DistributionSpecHash) {
+                    DistributionSpecHash distributionSpecHash = 
(DistributionSpecHash) childSpec;
+                    return new 
PhysicalProperties(distributionSpecHash.withShuffleType(ShuffleType.BUCKETED));
                 }
-                // TODO: change ENFORCED back to bucketed, when coordinator 
could process bucket on agg correctly.
-                return PhysicalProperties.createHash(new 
DistributionSpecHash(columns, ShuffleType.BUCKETED));
+                return new 
PhysicalProperties(childOutputProperty.getDistributionSpec());
             case DISTINCT_GLOBAL:
             default:
                 throw new RuntimeException("Could not derive output properties 
for agg phase: " + agg.getAggPhase());
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
index 4e4d52b551..56e94eb740 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
@@ -21,8 +21,6 @@ import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
 
-import com.google.common.collect.ImmutableList;
-
 /**
  * Implementation rule that convert logical aggregation to physical hash 
aggregation.
  */
@@ -33,7 +31,7 @@ public class LogicalAggToPhysicalHashAgg extends 
OneImplementationRuleFactory {
                 // TODO: for use a function to judge whether use stream
                 agg.getGroupByExpressions(),
                 agg.getOutputExpressions(),
-                ImmutableList.of(),
+                agg.getPartitionExpressions(),
                 agg.getAggPhase(),
                 false,
                 agg.isFinalPhase(),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index 77135074d4..911b6735ac 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -30,17 +30,20 @@ import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
  * Used to generate the merge agg node for distributed execution.
  * NOTICE: GLOBAL output expressions' ExprId should SAME with ORIGIN output 
expressions' ExprId.
+ * <pre>
  * If we have a query: SELECT SUM(v1 * v2) + 1 FROM t GROUP BY k + 1
  * the initial plan is:
  *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(SUM(v1 * 
v2) + 1) #2], groupByExpr: [k + 1])
@@ -61,6 +64,7 @@ import java.util.stream.Collectors;
  *   +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a])
  *       +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as 
a], groupByExpr: [k + 1, a])
  *           +-- childPlan
+ * </pre>
  *
  * TODO:
  *     1. use different class represent different phase aggregate
@@ -75,37 +79,32 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
                 .then(aggregate -> {
                     // used in secondDisassemble to transform local 
expressions into global
                     final Map<Expression, Expression> globalOutputSMap = 
Maps.newHashMap();
-                    // used in secondDisassemble to transform local 
expressions into global
-                    final Map<Expression, Expression> globalGroupBySMap = 
Maps.newHashMap();
-                    Pair<LogicalAggregate, Boolean> ret = 
firstDisassemble(aggregate, globalOutputSMap,
-                            globalGroupBySMap);
+                    Pair<LogicalAggregate<LogicalAggregate<GroupPlan>>, 
Boolean> ret
+                            = firstDisassemble(aggregate, globalOutputSMap);
                     if (!ret.second) {
                         return ret.first;
                     }
-                    return secondDisassemble(ret.first, globalOutputSMap, 
globalGroupBySMap);
+                    return secondDisassemble(ret.first, globalOutputSMap);
                 }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
     }
 
     // only support distinct function with group by
     // TODO: support distinct function without group by. (add second global 
phase)
-    private LogicalAggregate secondDisassemble(
-            LogicalAggregate<LogicalAggregate> aggregate,
-            Map<Expression, Expression> globalOutputSMap,
-            Map<Expression, Expression> globalGroupBySMap) {
+    private LogicalAggregate<LogicalAggregate<LogicalAggregate<GroupPlan>>> 
secondDisassemble(
+            LogicalAggregate<LogicalAggregate<GroupPlan>> aggregate,
+            Map<Expression, Expression> globalOutputSMap) {
         LogicalAggregate<GroupPlan> local = aggregate.child();
         // replace expression in globalOutputExprs and globalGroupByExprs
         List<NamedExpression> globalOutputExprs = 
local.getOutputExpressions().stream()
                 .map(e -> ExpressionUtils.replace(e, globalOutputSMap))
                 .map(NamedExpression.class::cast)
                 .collect(Collectors.toList());
-        List<Expression> globalGroupByExprs = 
local.getGroupByExpressions().stream()
-                .map(e -> ExpressionUtils.replace(e, globalGroupBySMap))
-                .collect(Collectors.toList());
 
         // generate new plan
-        LogicalAggregate globalAggregate = new LogicalAggregate<>(
-                globalGroupByExprs,
+        LogicalAggregate<LogicalAggregate<GroupPlan>> globalAggregate = new 
LogicalAggregate<>(
+                local.getGroupByExpressions(),
                 globalOutputExprs,
+                Optional.of(aggregate.getGroupByExpressions()),
                 true,
                 aggregate.isNormalized(),
                 false,
@@ -123,11 +122,10 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
         );
     }
 
-    private Pair<LogicalAggregate, Boolean> firstDisassemble(
+    private Pair<LogicalAggregate<LogicalAggregate<GroupPlan>>, Boolean> 
firstDisassemble(
             LogicalAggregate<GroupPlan> aggregate,
-            Map<Expression, Expression> globalOutputSMap,
-            Map<Expression, Expression> globalGroupBySMap) {
-        Boolean hasDistinct = Boolean.FALSE;
+            Map<Expression, Expression> globalOutputSMap) {
+        boolean hasDistinct = Boolean.FALSE;
         List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
         List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
         Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
@@ -155,18 +153,12 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
             if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
                 continue;
             }
-            if (originGroupByExpr instanceof SlotReference) {
-                inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
-                globalOutputSMap.put(originGroupByExpr, originGroupByExpr);
-                globalGroupBySMap.put(originGroupByExpr, originGroupByExpr);
-                localOutputExprs.add((SlotReference) originGroupByExpr);
-            } else {
-                NamedExpression localOutputExpr = new Alias(originGroupByExpr, 
originGroupByExpr.toSql());
-                inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
-                globalOutputSMap.put(localOutputExpr, 
localOutputExpr.toSlot());
-                globalGroupBySMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
-                localOutputExprs.add(localOutputExpr);
-            }
+            // group by expr must be SlotReference or NormalizeAggregate has 
bugs.
+            Preconditions.checkState(originGroupByExpr instanceof 
SlotReference,
+                    "normalize aggregate failed to normalize group by 
expression " + originGroupByExpr.toSql());
+            inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+            globalOutputSMap.put(originGroupByExpr, originGroupByExpr);
+            localOutputExprs.add((SlotReference) originGroupByExpr);
         }
         List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList();
         List<NamedExpression> distinctExprsForLocalOutput = 
Lists.newArrayList();
@@ -180,21 +172,12 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
                 if (aggregateFunction.isDistinct()) {
                     hasDistinct = Boolean.TRUE;
                     for (Expression expr : aggregateFunction.children()) {
-                        if (expr instanceof SlotReference) {
-                            distinctExprsForLocalOutput.add((SlotReference) 
expr);
-                            if (!inputSubstitutionMap.containsKey(expr)) {
-                                inputSubstitutionMap.put(expr, expr);
-                                globalOutputSMap.put(expr, expr);
-                                globalGroupBySMap.put(expr, expr);
-                            }
-                        } else {
-                            NamedExpression globalOutputExpr = new Alias(expr, 
expr.toSql());
-                            distinctExprsForLocalOutput.add(globalOutputExpr);
-                            if (!inputSubstitutionMap.containsKey(expr)) {
-                                inputSubstitutionMap.put(expr, 
globalOutputExpr.toSlot());
-                                globalOutputSMap.put(globalOutputExpr, 
globalOutputExpr.toSlot());
-                                globalGroupBySMap.put(expr, 
globalOutputExpr.toSlot());
-                            }
+                        Preconditions.checkState(expr instanceof 
SlotReference, "normalize aggregate failed to"
+                                + " normalize aggregate function " + 
aggregateFunction.toSql());
+                        distinctExprsForLocalOutput.add((SlotReference) expr);
+                        if (!inputSubstitutionMap.containsKey(expr)) {
+                            inputSubstitutionMap.put(expr, expr);
+                            globalOutputSMap.put(expr, expr);
                         }
                         distinctExprsForLocalGroupBy.add(expr);
                     }
@@ -221,7 +204,7 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
         localOutputExprs.addAll(distinctExprsForLocalOutput);
         localGroupByExprs.addAll(distinctExprsForLocalGroupBy);
         // 4. generate new plan
-        LogicalAggregate localAggregate = new LogicalAggregate<>(
+        LogicalAggregate<GroupPlan> localAggregate = new LogicalAggregate<>(
                 localGroupByExprs,
                 localOutputExprs,
                 true,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
index a888e06ce4..e1a2a832d3 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -41,8 +41,9 @@ import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
- * normalize aggregate's group keys to SlotReference and generate a 
LogicalProject top on LogicalAggregate
- * to hold to order of aggregate output, since aggregate output's order could 
change when we do translate.
+ * normalize aggregate's group keys and AggregateFunction's child to 
SlotReference
+ * and generate a LogicalProject top on LogicalAggregate to hold to order of 
aggregate output,
+ * since aggregate output's order could change when we do translate.
  *
  * Apply this rule could simplify the processing of enforce and translate.
  *
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index 06df06b92b..3dfd2ab06d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -55,9 +55,11 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
 
     private final boolean disassembled;
     private final boolean normalized;
+    private final AggPhase aggPhase;
     private final List<Expression> groupByExpressions;
     private final List<NamedExpression> outputExpressions;
-    private final AggPhase aggPhase;
+    // TODO: we should decide partition expression according to cost.
+    private final Optional<List<Expression>> partitionExpressions;
 
     // use for scenes containing distinct agg
     // 1. If there is LOCAL only, LOCAL is the final phase
@@ -74,7 +76,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
             List<Expression> groupByExpressions,
             List<NamedExpression> outputExpressions,
             CHILD_TYPE child) {
-        this(groupByExpressions, outputExpressions, false, false, true, 
AggPhase.LOCAL, child);
+        this(groupByExpressions, outputExpressions, Optional.empty(), false, 
false, true, AggPhase.LOCAL, child);
     }
 
     public LogicalAggregate(
@@ -85,7 +87,20 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
             boolean isFinalPhase,
             AggPhase aggPhase,
             CHILD_TYPE child) {
-        this(groupByExpressions, outputExpressions, disassembled, normalized, 
isFinalPhase,
+        this(groupByExpressions, outputExpressions, Optional.empty(), 
disassembled, normalized, isFinalPhase,
+                aggPhase, Optional.empty(), Optional.empty(), child);
+    }
+
+    public LogicalAggregate(
+            List<Expression> groupByExpressions,
+            List<NamedExpression> outputExpressions,
+            Optional<List<Expression>> partitionExpressions,
+            boolean disassembled,
+            boolean normalized,
+            boolean isFinalPhase,
+            AggPhase aggPhase,
+            CHILD_TYPE child) {
+        this(groupByExpressions, outputExpressions, partitionExpressions, 
disassembled, normalized, isFinalPhase,
                 aggPhase, Optional.empty(), Optional.empty(), child);
     }
 
@@ -95,6 +110,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
     public LogicalAggregate(
             List<Expression> groupByExpressions,
             List<NamedExpression> outputExpressions,
+            Optional<List<Expression>> partitionExpressions,
             boolean disassembled,
             boolean normalized,
             boolean isFinalPhase,
@@ -105,6 +121,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
         super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties, 
child);
         this.groupByExpressions = groupByExpressions;
         this.outputExpressions = outputExpressions;
+        this.partitionExpressions = partitionExpressions;
         this.disassembled = disassembled;
         this.normalized = normalized;
         this.isFinalPhase = isFinalPhase;
@@ -119,6 +136,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
         return outputExpressions;
     }
 
+    public List<Expression> getPartitionExpressions() {
+        return partitionExpressions.orElse(groupByExpressions);
+    }
+
     public AggPhase getAggPhase() {
         return aggPhase;
     }
@@ -141,7 +162,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
 
     @Override
     public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
-        return visitor.visitLogicalAggregate((LogicalAggregate<Plan>) this, 
context);
+        return visitor.visitLogicalAggregate(this, context);
     }
 
     @Override
@@ -177,6 +198,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
         LogicalAggregate that = (LogicalAggregate) o;
         return Objects.equals(groupByExpressions, that.groupByExpressions)
                 && Objects.equals(outputExpressions, that.outputExpressions)
+                && Objects.equals(partitionExpressions, 
that.partitionExpressions)
                 && aggPhase == that.aggPhase
                 && disassembled == that.disassembled
                 && normalized == that.normalized
@@ -185,31 +207,34 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
 
     @Override
     public int hashCode() {
-        return Objects.hash(groupByExpressions, outputExpressions, aggPhase, 
normalized, disassembled, isFinalPhase);
+        return Objects.hash(groupByExpressions, outputExpressions, 
partitionExpressions,
+                aggPhase, normalized, disassembled, isFinalPhase);
     }
 
     @Override
     public LogicalAggregate<Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 1);
-        return new LogicalAggregate<>(groupByExpressions, outputExpressions,
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions,
                 disassembled, normalized, isFinalPhase, aggPhase, 
children.get(0));
     }
 
     @Override
     public LogicalAggregate<Plan> 
withGroupExpression(Optional<GroupExpression> groupExpression) {
-        return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
disassembled, normalized, isFinalPhase,
-                aggPhase, groupExpression, 
Optional.of(getLogicalProperties()), children.get(0));
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions,
+                disassembled, normalized, isFinalPhase, aggPhase,
+                groupExpression, Optional.of(getLogicalProperties()), 
children.get(0));
     }
 
     @Override
     public LogicalAggregate<Plan> 
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
-        return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
disassembled, normalized, isFinalPhase,
-                aggPhase, Optional.empty(), logicalProperties, 
children.get(0));
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions,
+                disassembled, normalized, isFinalPhase, aggPhase,
+                Optional.empty(), logicalProperties, children.get(0));
     }
 
     public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> 
groupByExprList,
             List<NamedExpression> outputExpressionList) {
-        return new LogicalAggregate<>(groupByExprList, outputExpressionList,
+        return new LogicalAggregate<>(groupByExprList, outputExpressionList, 
partitionExpressions,
                 disassembled, normalized, isFinalPhase, aggPhase, child());
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
index 552bbb6770..f60158543b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
@@ -32,7 +32,6 @@ import org.apache.doris.statistics.StatsDeriveResult;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
-import org.apache.commons.collections.CollectionUtils;
 
 import java.util.List;
 import java.util.Objects;
@@ -132,7 +131,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
     }
 
     public List<Expression> getPartitionExpressions() {
-        return CollectionUtils.isEmpty(partitionExpressions) ? 
groupByExpressions : partitionExpressions;
+        return partitionExpressions;
     }
 
     @Override
@@ -142,8 +141,9 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
 
     @Override
     public List<? extends Expression> getExpressions() {
-        // TODO: partitionExprList maybe null.
-        return new 
ImmutableList.Builder<Expression>().addAll(groupByExpressions).addAll(outputExpressions)
+        return new ImmutableList.Builder<Expression>()
+                .addAll(groupByExpressions)
+                .addAll(outputExpressions)
                 .addAll(partitionExpressions).build();
     }
 
@@ -152,7 +152,8 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
         return Utils.toSqlString("PhysicalAggregate",
                 "phase", aggPhase,
                 "outputExpr", outputExpressions,
-                "groupByExpr", groupByExpressions
+                "groupByExpr", groupByExpressions,
+                "partitionExpr", partitionExpressions
         );
     }
 
@@ -177,34 +178,34 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
 
     @Override
     public int hashCode() {
-        return Objects.hash(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase, usingStream,
-                isFinalPhase);
+        return Objects.hash(groupByExpressions, outputExpressions, 
partitionExpressions,
+                aggPhase, usingStream, isFinalPhase);
     }
 
     @Override
     public PhysicalAggregate<Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 1);
-        return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase,
-                usingStream, isFinalPhase, getLogicalProperties(), 
children.get(0));
+        return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions,
+                aggPhase, usingStream, isFinalPhase, getLogicalProperties(), 
children.get(0));
     }
 
     @Override
     public PhysicalAggregate<CHILD_TYPE> 
withGroupExpression(Optional<GroupExpression> groupExpression) {
-        return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase,
-                usingStream, isFinalPhase, groupExpression, 
getLogicalProperties(), child());
+        return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions,
+                aggPhase, usingStream, isFinalPhase, groupExpression, 
getLogicalProperties(), child());
     }
 
     @Override
     public PhysicalAggregate<CHILD_TYPE> 
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
-        return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase,
-                usingStream, isFinalPhase, Optional.empty(), 
logicalProperties.get(), child());
+        return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions,
+                aggPhase, usingStream, isFinalPhase, Optional.empty(), 
logicalProperties.get(), child());
     }
 
     @Override
     public PhysicalAggregate<CHILD_TYPE> 
withPhysicalPropertiesAndStats(PhysicalProperties physicalProperties,
             StatsDeriveResult statsDeriveResult) {
-        return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase,
-                usingStream, isFinalPhase, Optional.empty(), 
getLogicalProperties(), physicalProperties,
-                statsDeriveResult, child());
+        return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions,
+                aggPhase, usingStream, isFinalPhase,
+                Optional.empty(), getLogicalProperties(), physicalProperties, 
statsDeriveResult, child());
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
index 6e53265dfb..48e7d043dd 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
@@ -303,7 +303,8 @@ public class ChildOutputPropertyDeriverTest {
                 groupPlan
         );
         GroupExpression groupExpression = new GroupExpression(aggregate);
-        PhysicalProperties child = new 
PhysicalProperties(DistributionSpecReplicated.INSTANCE,
+        DistributionSpecHash childHash = new 
DistributionSpecHash(Lists.newArrayList(partition.getExprId()), 
ShuffleType.BUCKETED);
+        PhysicalProperties child = new PhysicalProperties(childHash,
                 new OrderSpec(Lists.newArrayList(new OrderKey(new 
SlotReference("ignored", IntegerType.INSTANCE), true, true))));
 
         ChildOutputPropertyDeriver deriver = new 
ChildOutputPropertyDeriver(Lists.newArrayList(child));
@@ -331,7 +332,7 @@ public class ChildOutputPropertyDeriverTest {
         );
 
         GroupExpression groupExpression = new GroupExpression(aggregate);
-        PhysicalProperties child = new 
PhysicalProperties(DistributionSpecReplicated.INSTANCE,
+        PhysicalProperties child = new 
PhysicalProperties(DistributionSpecGather.INSTANCE,
                 new OrderSpec(Lists.newArrayList(new OrderKey(new 
SlotReference("ignored", IntegerType.INSTANCE), true, true))));
 
         ChildOutputPropertyDeriver deriver = new 
ChildOutputPropertyDeriver(Lists.newArrayList(child));
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index d583842865..6fa37387b9 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -113,65 +113,6 @@ public class AggregateDisassembleTest {
                 global.getOutputExpressions().get(1).getExprId());
     }
 
-    /**
-     * the initial plan is:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [(age + 1) as key, SUM(id) as sum], groupByExpr: [age + 1])
-     *   +--childPlan(id, name, age)
-     * we should rewrite to:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
-     *   +--Aggregate(phase: [LOCAL], outputExpr: [(age + 1) as a, SUM(id) as b], groupByExpr: [age + 1])
-     *       +--childPlan(id, name, age)
-     */
-    @Test
-    public void aliasGroupBy() {
-        List<Expression> groupExpressionList = Lists.newArrayList(
-                new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)));
-        List<NamedExpression> outputExpressionList = Lists.newArrayList(
-                new Alias(new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)), "key"),
-                new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
-        Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
-
-        Plan after = rewrite(root);
-
-        Assertions.assertTrue(after instanceof LogicalUnary);
-        Assertions.assertTrue(after instanceof LogicalAggregate);
-        Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
-        LogicalAggregate<Plan> global = (LogicalAggregate) after;
-        LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
-        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
-        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
-
-        Expression localOutput0 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
-        Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
-        Expression localGroupBy = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
-
-        Assertions.assertEquals(2, local.getOutputExpressions().size());
-        Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof Alias);
-        Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0).child(0));
-        Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias);
-        Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0));
-        Assertions.assertEquals(1, local.getGroupByExpressions().size());
-        Assertions.assertEquals(localGroupBy, local.getGroupByExpressions().get(0));
-
-        Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot();
-        Expression globalOutput1 = new Sum(local.getOutputExpressions().get(1).toSlot());
-        Expression globalGroupBy = local.getOutputExpressions().get(0).toSlot();
-
-        Assertions.assertEquals(2, global.getOutputExpressions().size());
-        Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof Alias);
-        Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0).child(0));
-        Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof Alias);
-        Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1).child(0));
-        Assertions.assertEquals(1, global.getGroupByExpressions().size());
-        Assertions.assertEquals(globalGroupBy, global.getGroupByExpressions().get(0));
-
-        // check id:
-        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
-                global.getOutputExpressions().get(0).getExprId());
-        Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
-                global.getOutputExpressions().get(1).getExprId());
-    }
-
     /**
      * the initial plan is:
      *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [])
@@ -272,20 +213,19 @@ public class AggregateDisassembleTest {
 
     /**
      * the initial plan is:
-     *   Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age + 1) + 2) as c], groupByExpr: [id + 3])
+     *   Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [id])
      *   +-- childPlan(id, name, age)
      * we should rewrite to:
-     *   Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct b) + 2) as c], groupByExpr: [a])
-     *   +-- Aggregate(phase: [GLOBAL], outputExpr: [a, b], groupByExpr: [a, b])
-     *       +-- Aggregate(phase: [LOCAL], outputExpr: [(id + 3) as a, (age + 1) as b], groupByExpr: [id + 3, age + 1])
+     *   Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [id])
+     *   +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr: [id, age])
+     *       +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr: [id, age])
      *           +-- childPlan(id, name, age)
      */
     @Test
     public void distinctAggregateWithGroupBy() {
-        List<Expression> groupExpressionList = Lists.newArrayList(
-                new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3)));
+        List<Expression> groupExpressionList = Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
         List<NamedExpression> outputExpressionList = Lists.newArrayList(new Alias(
-                new Add(new Count(new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)), true),
+                new Add(new Count(rStudent.getOutput().get(2).toSlot(), true),
                         new IntegerLiteral(2)), "c"));
         Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
 
@@ -301,20 +241,20 @@ public class AggregateDisassembleTest {
         Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
         Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
         // check local:
-        // id + 3
-        Expression localOutput0 = new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
-        // age + 1
-        Expression localOutput1 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
-        // id + 3
-        Expression localGroupBy0 = new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
-        // age + 1
-        Expression localGroupBy1 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
+        // id
+        Expression localOutput0 = rStudent.getOutput().get(0).toSlot();
+        // age
+        Expression localOutput1 = rStudent.getOutput().get(2).toSlot();
+        // id
+        Expression localGroupBy0 = rStudent.getOutput().get(0).toSlot();
+        // age
+        Expression localGroupBy1 = rStudent.getOutput().get(2).toSlot();
 
         Assertions.assertEquals(2, local.getOutputExpressions().size());
-        Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof Alias);
-        Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0).child(0));
-        Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias);
-        Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0));
+        Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference);
+        Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0));
+        Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof SlotReference);
+        Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1));
         Assertions.assertEquals(2, local.getGroupByExpressions().size());
         Assertions.assertEquals(localGroupBy0, local.getGroupByExpressions().get(0));
         Assertions.assertEquals(localGroupBy1, local.getGroupByExpressions().get(1));


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to