This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 90f11ed7c1 [enhancement](Nereids) remove unnecessary exchange between
global and distinct local aggregate node (#13057)
90f11ed7c1 is described below
commit 90f11ed7c1116a48a7a10acb0a6bf50682fddd19
Author: morrySnow <[email protected]>
AuthorDate: Thu Sep 29 23:12:37 2022 +0800
[enhancement](Nereids) remove unnecessary exchange between global and
distinct local aggregate node (#13057)
Add partition expressions to LogicalAggregate. When disassembling an aggregate
that contains a distinct aggregate function, set them to the aggregate's
original group-by expression list.
---
.../nereids/jobs/cascades/CostAndEnforcerJob.java | 5 +-
.../properties/ChildOutputPropertyDeriver.java | 17 ++--
.../LogicalAggToPhysicalHashAgg.java | 4 +-
.../rules/rewrite/AggregateDisassemble.java | 75 +++++++----------
.../rules/rewrite/logical/NormalizeAggregate.java | 5 +-
.../trees/plans/logical/LogicalAggregate.java | 47 ++++++++---
.../trees/plans/physical/PhysicalAggregate.java | 33 ++++----
.../properties/ChildOutputPropertyDeriverTest.java | 5 +-
.../rewrite/logical/AggregateDisassembleTest.java | 96 ++++------------------
9 files changed, 114 insertions(+), 173 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
index a11161f069..188acaeba5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
@@ -212,10 +212,9 @@ public class CostAndEnforcerJob extends Job implements
Cloneable {
private void enforce(PhysicalProperties outputProperty,
List<PhysicalProperties> requestChildrenProperty) {
PhysicalProperties requiredProperties =
context.getRequiredProperties();
-
- EnforceMissingPropertiesHelper enforceMissingPropertiesHelper
- = new EnforceMissingPropertiesHelper(context, groupExpression,
curTotalCost);
if (!outputProperty.satisfy(requiredProperties)) {
+ EnforceMissingPropertiesHelper enforceMissingPropertiesHelper
+ = new EnforceMissingPropertiesHelper(context,
groupExpression, curTotalCost);
PhysicalProperties addEnforcedProperty =
enforceMissingPropertiesHelper
.enforceProperty(outputProperty, requiredProperties);
curTotalCost = enforceMissingPropertiesHelper.getCurTotalCost();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
index 43cedd78a4..e1ff5168ad 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
@@ -20,8 +20,6 @@ package org.apache.doris.nereids.properties;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
-import org.apache.doris.nereids.trees.expressions.ExprId;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
@@ -42,7 +40,6 @@ import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Objects;
-import java.util.stream.Collectors;
/**
* Used for property drive.
@@ -76,18 +73,14 @@ public class ChildOutputPropertyDeriver extends
PlanVisitor<PhysicalProperties,
// TODO: add distinct phase output properties
switch (agg.getAggPhase()) {
case LOCAL:
- return new
PhysicalProperties(childOutputProperty.getDistributionSpec());
case GLOBAL:
case DISTINCT_LOCAL:
- List<ExprId> columns = agg.getPartitionExpressions().stream()
- .map(SlotReference.class::cast)
- .map(SlotReference::getExprId)
- .collect(Collectors.toList());
- if (columns.isEmpty()) {
- return PhysicalProperties.GATHER;
+ DistributionSpec childSpec =
childOutputProperty.getDistributionSpec();
+ if (childSpec instanceof DistributionSpecHash) {
+ DistributionSpecHash distributionSpecHash =
(DistributionSpecHash) childSpec;
+ return new
PhysicalProperties(distributionSpecHash.withShuffleType(ShuffleType.BUCKETED));
}
- // TODO: change ENFORCED back to bucketed, when coordinator
could process bucket on agg correctly.
- return PhysicalProperties.createHash(new
DistributionSpecHash(columns, ShuffleType.BUCKETED));
+ return new
PhysicalProperties(childOutputProperty.getDistributionSpec());
case DISTINCT_GLOBAL:
default:
throw new RuntimeException("Could not derive output properties
for agg phase: " + agg.getAggPhase());
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
index 4e4d52b551..56e94eb740 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
@@ -21,8 +21,6 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
-import com.google.common.collect.ImmutableList;
-
/**
* Implementation rule that convert logical aggregation to physical hash
aggregation.
*/
@@ -33,7 +31,7 @@ public class LogicalAggToPhysicalHashAgg extends
OneImplementationRuleFactory {
// TODO: for use a function to judge whether use stream
agg.getGroupByExpressions(),
agg.getOutputExpressions(),
- ImmutableList.of(),
+ agg.getPartitionExpressions(),
agg.getAggPhase(),
false,
agg.isFinalPhase(),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index 77135074d4..911b6735ac 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -30,17 +30,20 @@ import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.util.ExpressionUtils;
+import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Used to generate the merge agg node for distributed execution.
* NOTICE: GLOBAL output expressions' ExprId should SAME with ORIGIN output
expressions' ExprId.
+ * <pre>
* If we have a query: SELECT SUM(v1 * v2) + 1 FROM t GROUP BY k + 1
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(SUM(v1 *
v2) + 1) #2], groupByExpr: [k + 1])
@@ -61,6 +64,7 @@ import java.util.stream.Collectors;
* +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a])
* +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as
a], groupByExpr: [k + 1, a])
* +-- childPlan
+ * </pre>
*
* TODO:
* 1. use different class represent different phase aggregate
@@ -75,37 +79,32 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
.then(aggregate -> {
// used in secondDisassemble to transform local
expressions into global
final Map<Expression, Expression> globalOutputSMap =
Maps.newHashMap();
- // used in secondDisassemble to transform local
expressions into global
- final Map<Expression, Expression> globalGroupBySMap =
Maps.newHashMap();
- Pair<LogicalAggregate, Boolean> ret =
firstDisassemble(aggregate, globalOutputSMap,
- globalGroupBySMap);
+ Pair<LogicalAggregate<LogicalAggregate<GroupPlan>>,
Boolean> ret
+ = firstDisassemble(aggregate, globalOutputSMap);
if (!ret.second) {
return ret.first;
}
- return secondDisassemble(ret.first, globalOutputSMap,
globalGroupBySMap);
+ return secondDisassemble(ret.first, globalOutputSMap);
}).toRule(RuleType.AGGREGATE_DISASSEMBLE);
}
// only support distinct function with group by
// TODO: support distinct function without group by. (add second global
phase)
- private LogicalAggregate secondDisassemble(
- LogicalAggregate<LogicalAggregate> aggregate,
- Map<Expression, Expression> globalOutputSMap,
- Map<Expression, Expression> globalGroupBySMap) {
+ private LogicalAggregate<LogicalAggregate<LogicalAggregate<GroupPlan>>>
secondDisassemble(
+ LogicalAggregate<LogicalAggregate<GroupPlan>> aggregate,
+ Map<Expression, Expression> globalOutputSMap) {
LogicalAggregate<GroupPlan> local = aggregate.child();
// replace expression in globalOutputExprs and globalGroupByExprs
List<NamedExpression> globalOutputExprs =
local.getOutputExpressions().stream()
.map(e -> ExpressionUtils.replace(e, globalOutputSMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
- List<Expression> globalGroupByExprs =
local.getGroupByExpressions().stream()
- .map(e -> ExpressionUtils.replace(e, globalGroupBySMap))
- .collect(Collectors.toList());
// generate new plan
- LogicalAggregate globalAggregate = new LogicalAggregate<>(
- globalGroupByExprs,
+ LogicalAggregate<LogicalAggregate<GroupPlan>> globalAggregate = new
LogicalAggregate<>(
+ local.getGroupByExpressions(),
globalOutputExprs,
+ Optional.of(aggregate.getGroupByExpressions()),
true,
aggregate.isNormalized(),
false,
@@ -123,11 +122,10 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
);
}
- private Pair<LogicalAggregate, Boolean> firstDisassemble(
+ private Pair<LogicalAggregate<LogicalAggregate<GroupPlan>>, Boolean>
firstDisassemble(
LogicalAggregate<GroupPlan> aggregate,
- Map<Expression, Expression> globalOutputSMap,
- Map<Expression, Expression> globalGroupBySMap) {
- Boolean hasDistinct = Boolean.FALSE;
+ Map<Expression, Expression> globalOutputSMap) {
+ boolean hasDistinct = Boolean.FALSE;
List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
@@ -155,18 +153,12 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
continue;
}
- if (originGroupByExpr instanceof SlotReference) {
- inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
- globalOutputSMap.put(originGroupByExpr, originGroupByExpr);
- globalGroupBySMap.put(originGroupByExpr, originGroupByExpr);
- localOutputExprs.add((SlotReference) originGroupByExpr);
- } else {
- NamedExpression localOutputExpr = new Alias(originGroupByExpr,
originGroupByExpr.toSql());
- inputSubstitutionMap.put(originGroupByExpr,
localOutputExpr.toSlot());
- globalOutputSMap.put(localOutputExpr,
localOutputExpr.toSlot());
- globalGroupBySMap.put(originGroupByExpr,
localOutputExpr.toSlot());
- localOutputExprs.add(localOutputExpr);
- }
+ // group by expr must be SlotReference or NormalizeAggregate has
bugs.
+ Preconditions.checkState(originGroupByExpr instanceof
SlotReference,
+ "normalize aggregate failed to normalize group by
expression " + originGroupByExpr.toSql());
+ inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+ globalOutputSMap.put(originGroupByExpr, originGroupByExpr);
+ localOutputExprs.add((SlotReference) originGroupByExpr);
}
List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList();
List<NamedExpression> distinctExprsForLocalOutput =
Lists.newArrayList();
@@ -180,21 +172,12 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
if (aggregateFunction.isDistinct()) {
hasDistinct = Boolean.TRUE;
for (Expression expr : aggregateFunction.children()) {
- if (expr instanceof SlotReference) {
- distinctExprsForLocalOutput.add((SlotReference)
expr);
- if (!inputSubstitutionMap.containsKey(expr)) {
- inputSubstitutionMap.put(expr, expr);
- globalOutputSMap.put(expr, expr);
- globalGroupBySMap.put(expr, expr);
- }
- } else {
- NamedExpression globalOutputExpr = new Alias(expr,
expr.toSql());
- distinctExprsForLocalOutput.add(globalOutputExpr);
- if (!inputSubstitutionMap.containsKey(expr)) {
- inputSubstitutionMap.put(expr,
globalOutputExpr.toSlot());
- globalOutputSMap.put(globalOutputExpr,
globalOutputExpr.toSlot());
- globalGroupBySMap.put(expr,
globalOutputExpr.toSlot());
- }
+ Preconditions.checkState(expr instanceof
SlotReference, "normalize aggregate failed to"
+ + " normalize aggregate function " +
aggregateFunction.toSql());
+ distinctExprsForLocalOutput.add((SlotReference) expr);
+ if (!inputSubstitutionMap.containsKey(expr)) {
+ inputSubstitutionMap.put(expr, expr);
+ globalOutputSMap.put(expr, expr);
}
distinctExprsForLocalGroupBy.add(expr);
}
@@ -221,7 +204,7 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
localOutputExprs.addAll(distinctExprsForLocalOutput);
localGroupByExprs.addAll(distinctExprsForLocalGroupBy);
// 4. generate new plan
- LogicalAggregate localAggregate = new LogicalAggregate<>(
+ LogicalAggregate<GroupPlan> localAggregate = new LogicalAggregate<>(
localGroupByExprs,
localOutputExprs,
true,
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
index a888e06ce4..e1a2a832d3 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -41,8 +41,9 @@ import java.util.Set;
import java.util.stream.Collectors;
/**
- * normalize aggregate's group keys to SlotReference and generate a
LogicalProject top on LogicalAggregate
- * to hold to order of aggregate output, since aggregate output's order could
change when we do translate.
+ * normalize aggregate's group keys and AggregateFunction's child to
SlotReference
+ * and generate a LogicalProject top on LogicalAggregate to hold to order of
aggregate output,
+ * since aggregate output's order could change when we do translate.
*
* Apply this rule could simplify the processing of enforce and translate.
*
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index 06df06b92b..3dfd2ab06d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -55,9 +55,11 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
private final boolean disassembled;
private final boolean normalized;
+ private final AggPhase aggPhase;
private final List<Expression> groupByExpressions;
private final List<NamedExpression> outputExpressions;
- private final AggPhase aggPhase;
+ // TODO: we should decide partition expression according to cost.
+ private final Optional<List<Expression>> partitionExpressions;
// use for scenes containing distinct agg
// 1. If there is LOCAL only, LOCAL is the final phase
@@ -74,7 +76,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
CHILD_TYPE child) {
- this(groupByExpressions, outputExpressions, false, false, true,
AggPhase.LOCAL, child);
+ this(groupByExpressions, outputExpressions, Optional.empty(), false,
false, true, AggPhase.LOCAL, child);
}
public LogicalAggregate(
@@ -85,7 +87,20 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
boolean isFinalPhase,
AggPhase aggPhase,
CHILD_TYPE child) {
- this(groupByExpressions, outputExpressions, disassembled, normalized,
isFinalPhase,
+ this(groupByExpressions, outputExpressions, Optional.empty(),
disassembled, normalized, isFinalPhase,
+ aggPhase, Optional.empty(), Optional.empty(), child);
+ }
+
+ public LogicalAggregate(
+ List<Expression> groupByExpressions,
+ List<NamedExpression> outputExpressions,
+ Optional<List<Expression>> partitionExpressions,
+ boolean disassembled,
+ boolean normalized,
+ boolean isFinalPhase,
+ AggPhase aggPhase,
+ CHILD_TYPE child) {
+ this(groupByExpressions, outputExpressions, partitionExpressions,
disassembled, normalized, isFinalPhase,
aggPhase, Optional.empty(), Optional.empty(), child);
}
@@ -95,6 +110,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
public LogicalAggregate(
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
+ Optional<List<Expression>> partitionExpressions,
boolean disassembled,
boolean normalized,
boolean isFinalPhase,
@@ -105,6 +121,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties,
child);
this.groupByExpressions = groupByExpressions;
this.outputExpressions = outputExpressions;
+ this.partitionExpressions = partitionExpressions;
this.disassembled = disassembled;
this.normalized = normalized;
this.isFinalPhase = isFinalPhase;
@@ -119,6 +136,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
return outputExpressions;
}
+ public List<Expression> getPartitionExpressions() {
+ return partitionExpressions.orElse(groupByExpressions);
+ }
+
public AggPhase getAggPhase() {
return aggPhase;
}
@@ -141,7 +162,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
- return visitor.visitLogicalAggregate((LogicalAggregate<Plan>) this,
context);
+ return visitor.visitLogicalAggregate(this, context);
}
@Override
@@ -177,6 +198,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
LogicalAggregate that = (LogicalAggregate) o;
return Objects.equals(groupByExpressions, that.groupByExpressions)
&& Objects.equals(outputExpressions, that.outputExpressions)
+ && Objects.equals(partitionExpressions,
that.partitionExpressions)
&& aggPhase == that.aggPhase
&& disassembled == that.disassembled
&& normalized == that.normalized
@@ -185,31 +207,34 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
@Override
public int hashCode() {
- return Objects.hash(groupByExpressions, outputExpressions, aggPhase,
normalized, disassembled, isFinalPhase);
+ return Objects.hash(groupByExpressions, outputExpressions,
partitionExpressions,
+ aggPhase, normalized, disassembled, isFinalPhase);
}
@Override
public LogicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
- return new LogicalAggregate<>(groupByExpressions, outputExpressions,
+ return new LogicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions,
disassembled, normalized, isFinalPhase, aggPhase,
children.get(0));
}
@Override
public LogicalAggregate<Plan>
withGroupExpression(Optional<GroupExpression> groupExpression) {
- return new LogicalAggregate<>(groupByExpressions, outputExpressions,
disassembled, normalized, isFinalPhase,
- aggPhase, groupExpression,
Optional.of(getLogicalProperties()), children.get(0));
+ return new LogicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions,
+ disassembled, normalized, isFinalPhase, aggPhase,
+ groupExpression, Optional.of(getLogicalProperties()),
children.get(0));
}
@Override
public LogicalAggregate<Plan>
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
- return new LogicalAggregate<>(groupByExpressions, outputExpressions,
disassembled, normalized, isFinalPhase,
- aggPhase, Optional.empty(), logicalProperties,
children.get(0));
+ return new LogicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions,
+ disassembled, normalized, isFinalPhase, aggPhase,
+ Optional.empty(), logicalProperties, children.get(0));
}
public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression>
groupByExprList,
List<NamedExpression> outputExpressionList) {
- return new LogicalAggregate<>(groupByExprList, outputExpressionList,
+ return new LogicalAggregate<>(groupByExprList, outputExpressionList,
partitionExpressions,
disassembled, normalized, isFinalPhase, aggPhase, child());
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
index 552bbb6770..f60158543b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
@@ -32,7 +32,6 @@ import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
-import org.apache.commons.collections.CollectionUtils;
import java.util.List;
import java.util.Objects;
@@ -132,7 +131,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
}
public List<Expression> getPartitionExpressions() {
- return CollectionUtils.isEmpty(partitionExpressions) ?
groupByExpressions : partitionExpressions;
+ return partitionExpressions;
}
@Override
@@ -142,8 +141,9 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
@Override
public List<? extends Expression> getExpressions() {
- // TODO: partitionExprList maybe null.
- return new
ImmutableList.Builder<Expression>().addAll(groupByExpressions).addAll(outputExpressions)
+ return new ImmutableList.Builder<Expression>()
+ .addAll(groupByExpressions)
+ .addAll(outputExpressions)
.addAll(partitionExpressions).build();
}
@@ -152,7 +152,8 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
return Utils.toSqlString("PhysicalAggregate",
"phase", aggPhase,
"outputExpr", outputExpressions,
- "groupByExpr", groupByExpressions
+ "groupByExpr", groupByExpressions,
+ "partitionExpr", partitionExpressions
);
}
@@ -177,34 +178,34 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
@Override
public int hashCode() {
- return Objects.hash(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase, usingStream,
- isFinalPhase);
+ return Objects.hash(groupByExpressions, outputExpressions,
partitionExpressions,
+ aggPhase, usingStream, isFinalPhase);
}
@Override
public PhysicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
- return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase,
- usingStream, isFinalPhase, getLogicalProperties(),
children.get(0));
+ return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions,
+ aggPhase, usingStream, isFinalPhase, getLogicalProperties(),
children.get(0));
}
@Override
public PhysicalAggregate<CHILD_TYPE>
withGroupExpression(Optional<GroupExpression> groupExpression) {
- return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase,
- usingStream, isFinalPhase, groupExpression,
getLogicalProperties(), child());
+ return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions,
+ aggPhase, usingStream, isFinalPhase, groupExpression,
getLogicalProperties(), child());
}
@Override
public PhysicalAggregate<CHILD_TYPE>
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
- return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase,
- usingStream, isFinalPhase, Optional.empty(),
logicalProperties.get(), child());
+ return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions,
+ aggPhase, usingStream, isFinalPhase, Optional.empty(),
logicalProperties.get(), child());
}
@Override
public PhysicalAggregate<CHILD_TYPE>
withPhysicalPropertiesAndStats(PhysicalProperties physicalProperties,
StatsDeriveResult statsDeriveResult) {
- return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase,
- usingStream, isFinalPhase, Optional.empty(),
getLogicalProperties(), physicalProperties,
- statsDeriveResult, child());
+ return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions,
+ aggPhase, usingStream, isFinalPhase,
+ Optional.empty(), getLogicalProperties(), physicalProperties,
statsDeriveResult, child());
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
index 6e53265dfb..48e7d043dd 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
@@ -303,7 +303,8 @@ public class ChildOutputPropertyDeriverTest {
groupPlan
);
GroupExpression groupExpression = new GroupExpression(aggregate);
- PhysicalProperties child = new
PhysicalProperties(DistributionSpecReplicated.INSTANCE,
+ DistributionSpecHash childHash = new
DistributionSpecHash(Lists.newArrayList(partition.getExprId()),
ShuffleType.BUCKETED);
+ PhysicalProperties child = new PhysicalProperties(childHash,
new OrderSpec(Lists.newArrayList(new OrderKey(new
SlotReference("ignored", IntegerType.INSTANCE), true, true))));
ChildOutputPropertyDeriver deriver = new
ChildOutputPropertyDeriver(Lists.newArrayList(child));
@@ -331,7 +332,7 @@ public class ChildOutputPropertyDeriverTest {
);
GroupExpression groupExpression = new GroupExpression(aggregate);
- PhysicalProperties child = new
PhysicalProperties(DistributionSpecReplicated.INSTANCE,
+ PhysicalProperties child = new
PhysicalProperties(DistributionSpecGather.INSTANCE,
new OrderSpec(Lists.newArrayList(new OrderKey(new
SlotReference("ignored", IntegerType.INSTANCE), true, true))));
ChildOutputPropertyDeriver deriver = new
ChildOutputPropertyDeriver(Lists.newArrayList(child));
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index d583842865..6fa37387b9 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -113,65 +113,6 @@ public class AggregateDisassembleTest {
global.getOutputExpressions().get(1).getExprId());
}
- /**
- * the initial plan is:
- * Aggregate(phase: [GLOBAL], outputExpr: [(age + 1) as key, SUM(id) as
sum], groupByExpr: [age + 1])
- * +--childPlan(id, name, age)
- * we should rewrite to:
- * Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr:
[a])
- * +--Aggregate(phase: [LOCAL], outputExpr: [(age + 1) as a, SUM(id) as
b], groupByExpr: [age + 1])
- * +--childPlan(id, name, age)
- */
- @Test
- public void aliasGroupBy() {
- List<Expression> groupExpressionList = Lists.newArrayList(
- new Add(rStudent.getOutput().get(2).toSlot(), new
IntegerLiteral(1)));
- List<NamedExpression> outputExpressionList = Lists.newArrayList(
- new Alias(new Add(rStudent.getOutput().get(2).toSlot(), new
IntegerLiteral(1)), "key"),
- new Alias(new Sum(rStudent.getOutput().get(0).toSlot()),
"sum"));
- Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
-
- Plan after = rewrite(root);
-
- Assertions.assertTrue(after instanceof LogicalUnary);
- Assertions.assertTrue(after instanceof LogicalAggregate);
- Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
- LogicalAggregate<Plan> global = (LogicalAggregate) after;
- LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
- Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
- Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
-
- Expression localOutput0 = new
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
- Expression localOutput1 = new
Sum(rStudent.getOutput().get(0).toSlot());
- Expression localGroupBy = new
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
-
- Assertions.assertEquals(2, local.getOutputExpressions().size());
- Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof
Alias);
- Assertions.assertEquals(localOutput0,
local.getOutputExpressions().get(0).child(0));
- Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof
Alias);
- Assertions.assertEquals(localOutput1,
local.getOutputExpressions().get(1).child(0));
- Assertions.assertEquals(1, local.getGroupByExpressions().size());
- Assertions.assertEquals(localGroupBy,
local.getGroupByExpressions().get(0));
-
- Expression globalOutput0 =
local.getOutputExpressions().get(0).toSlot();
- Expression globalOutput1 = new
Sum(local.getOutputExpressions().get(1).toSlot());
- Expression globalGroupBy =
local.getOutputExpressions().get(0).toSlot();
-
- Assertions.assertEquals(2, global.getOutputExpressions().size());
- Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof
Alias);
- Assertions.assertEquals(globalOutput0,
global.getOutputExpressions().get(0).child(0));
- Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof
Alias);
- Assertions.assertEquals(globalOutput1,
global.getOutputExpressions().get(1).child(0));
- Assertions.assertEquals(1, global.getGroupByExpressions().size());
- Assertions.assertEquals(globalGroupBy,
global.getGroupByExpressions().get(0));
-
- // check id:
- Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
- global.getOutputExpressions().get(0).getExprId());
- Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
- global.getOutputExpressions().get(1).getExprId());
- }
-
/**
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr:
[])
@@ -272,20 +213,19 @@ public class AggregateDisassembleTest {
/**
* the initial plan is:
- * Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age + 1) + 2)
as c], groupByExpr: [id + 3])
+ * Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as
c], groupByExpr: [id])
* +-- childPlan(id, name, age)
* we should rewrite to:
- * Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct b) +
2) as c], groupByExpr: [a])
- * +-- Aggregate(phase: [GLOBAL], outputExpr: [a, b], groupByExpr: [a,
b])
- * +-- Aggregate(phase: [LOCAL], outputExpr: [(id + 3) as a, (age +
1) as b], groupByExpr: [id + 3, age + 1])
+ * Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct age)
+ 2) as c], groupByExpr: [id])
+ * +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr:
[id, age])
+ * +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr:
[id, age])
* +-- childPlan(id, name, age)
*/
@Test
public void distinctAggregateWithGroupBy() {
- List<Expression> groupExpressionList = Lists.newArrayList(
- new Add(rStudent.getOutput().get(0).toSlot(), new
IntegerLiteral(3)));
+ List<Expression> groupExpressionList =
Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
List<NamedExpression> outputExpressionList = Lists.newArrayList(new
Alias(
- new Add(new Count(new
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)), true),
+ new Add(new Count(rStudent.getOutput().get(2).toSlot(), true),
new IntegerLiteral(2)), "c"));
Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
@@ -301,20 +241,20 @@ public class AggregateDisassembleTest {
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
// check local:
- // id + 3
- Expression localOutput0 = new
Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
- // age + 1
- Expression localOutput1 = new
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
- // id + 3
- Expression localGroupBy0 = new
Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
- // age + 1
- Expression localGroupBy1 = new
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
+ // id
+ Expression localOutput0 = rStudent.getOutput().get(0).toSlot();
+ // age
+ Expression localOutput1 = rStudent.getOutput().get(2).toSlot();
+ // id
+ Expression localGroupBy0 = rStudent.getOutput().get(0).toSlot();
+ // age
+ Expression localGroupBy1 = rStudent.getOutput().get(2).toSlot();
Assertions.assertEquals(2, local.getOutputExpressions().size());
- Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof
Alias);
- Assertions.assertEquals(localOutput0,
local.getOutputExpressions().get(0).child(0));
- Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof
Alias);
- Assertions.assertEquals(localOutput1,
local.getOutputExpressions().get(1).child(0));
+ Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof
SlotReference);
+ Assertions.assertEquals(localOutput0,
local.getOutputExpressions().get(0));
+ Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof
SlotReference);
+ Assertions.assertEquals(localOutput1,
local.getOutputExpressions().get(1));
Assertions.assertEquals(2, local.getGroupByExpressions().size());
Assertions.assertEquals(localGroupBy0,
local.getGroupByExpressions().get(0));
Assertions.assertEquals(localGroupBy1,
local.getGroupByExpressions().get(1));
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]