This is an automated email from the ASF dual-hosted git repository.
huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 593a916ae6 [feature](nereids) split AggregateDisassemble into two
rules (#14611)
593a916ae6 is described below
commit 593a916ae64e0ffd8884651ce4ec8279dd755658
Author: yinzhijian <[email protected]>
AuthorDate: Wed Nov 30 14:02:42 2022 +0800
[feature](nereids) split AggregateDisassemble into two rules (#14611)
# Proposed changes
Issue Number: close #14280
## Problem summary
The AggregateDisassemble rule is refactored and split into two rules that do not
depend on each other:
1. AggregateDisassemble splits the agg into two phases: Local, Global.
1.1. For COUNT(DISTINCT ...), the implementation is:
multi_distinct_count(update) + multi_distinct_count(merge).
2. DistinctAggregateDisassemble splits the agg into four phases: Local,
Global, Distinct Local, Distinct Global.
2.1. For COUNT(DISTINCT ...), the implementation is:
multi_distinct_count(update) + multi_distinct_count(merge) + sum(update) +
sum(merge).
---
.../glue/translator/ExpressionTranslator.java | 12 +-
.../glue/translator/PhysicalPlanTranslator.java | 44 ++---
.../properties/ChildOutputPropertyDeriver.java | 3 +-
.../nereids/properties/RequestPropertyDeriver.java | 32 +++-
.../org/apache/doris/nereids/rules/RuleSet.java | 2 +
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../doris/nereids/rules/analysis/BindFunction.java | 3 +-
.../rules/rewrite/AggregateDisassemble.java | 110 ++----------
.../rewrite/DistinctAggregateDisassemble.java | 190 +++++++++++++++++++++
.../functions/agg/AggregateFunction.java | 10 +-
.../expressions/functions/agg/AggregateParam.java | 43 +++--
.../trees/plans/logical/LogicalAggregate.java | 33 ++++
.../rewrite/logical/AggregateDisassembleTest.java | 182 +++++++++++---------
.../trees/expressions/ExpressionEqualsTest.java | 6 +-
14 files changed, 435 insertions(+), 236 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index cf9532253b..1863da7c7f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -68,6 +68,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.thrift.TFunctionBinaryType;
@@ -289,8 +290,17 @@ public class ExpressionTranslator extends
DefaultExpressionVisitor<Expr, PlanTra
: NullableMode.ALWAYS_NOT_NULLABLE;
boolean isAnalyticFunction = false;
+ String functionName = function.isDistinct() ? "MULTI_DISTINCT_" +
function.getName() : function.getName();
+ if (function.getAggregateParam().aggPhase == AggPhase.DISTINCT_LOCAL
+ || function.getAggregateParam().aggPhase ==
AggPhase.DISTINCT_GLOBAL) {
+ if (function.getName().equalsIgnoreCase("count")) {
+ functionName = "SUM";
+ } else {
+ functionName = function.getName();
+ }
+ }
org.apache.doris.catalog.AggregateFunction catalogFunction = new
org.apache.doris.catalog.AggregateFunction(
- new FunctionName(function.getName()), argTypes,
+ new FunctionName(functionName), argTypes,
function.getDataType().toCatalogDataType(),
function.getIntermediateTypes().toCatalogDataType(),
function.hasVarArguments(),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 235f0265b6..417b0b049e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -159,7 +159,6 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
/**
* Translate Agg.
- * todo: support DISTINCT
*/
@Override
public PlanFragment visitPhysicalAggregate(
@@ -201,44 +200,23 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
List<Expr> execPartitionExpressions = partitionExpressionList.stream()
.map(e -> ExpressionTranslator.translate(e,
context)).collect(Collectors.toList());
DataPartition mergePartition = DataPartition.UNPARTITIONED;
- if (CollectionUtils.isNotEmpty(execPartitionExpressions)) {
+ if (CollectionUtils.isNotEmpty(execPartitionExpressions)
+ && aggregate.getAggPhase() != AggPhase.DISTINCT_GLOBAL) {
mergePartition =
DataPartition.hashPartitioned(execPartitionExpressions);
}
// 3. generate output tuple
List<Slot> slotList = Lists.newArrayList();
TupleDescriptor outputTupleDesc;
- if (aggregate.getAggPhase() == AggPhase.LOCAL
- || (aggregate.getAggPhase() == AggPhase.GLOBAL &&
aggregate.isFinalPhase())
- || aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
- slotList.addAll(groupSlotList);
- slotList.addAll(aggFunctionOutput);
- outputTupleDesc = generateTupleDesc(slotList, null, context);
- } else {
- // In the distinct agg scenario, global shares local's desc
- AggregationNode localAggNode;
- if (inputPlanFragment.getPlanRoot() instanceof ExchangeNode) {
- localAggNode = (AggregationNode)
inputPlanFragment.getPlanRoot().getChild(0);
- } else {
- // If the group by expr hits the partition key, there may be
no exchange node
- localAggNode = (AggregationNode)
inputPlanFragment.getPlanRoot();
- }
- outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc();
- }
+ slotList.addAll(groupSlotList);
+ slotList.addAll(aggFunctionOutput);
+ outputTupleDesc = generateTupleDesc(slotList, null, context);
- // TODO: move setMergeForNereids to
ExpressionTranslator.visitAggregateFunction
- if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
- for (FunctionCallExpr execAggregateFunction :
execAggregateFunctions) {
- if (!execAggregateFunction.isDistinct()) {
- execAggregateFunction.setMergeForNereids(true);
- }
- }
- }
AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions,
execAggregateFunctions, outputTupleDesc,
outputTupleDesc, aggregate.getAggPhase().toExec());
AggregationNode aggregationNode = new
AggregationNode(context.nextPlanNodeId(),
inputPlanFragment.getPlanRoot(), aggInfo);
- if (!aggregate.isFinalPhase()) {
+ if (!aggregate.getAggPhase().isGlobal() && !aggregate.isFinalPhase()) {
aggregationNode.unsetNeedsFinalize();
}
PlanFragment currentFragment = inputPlanFragment;
@@ -247,8 +225,12 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
aggregationNode.setUseStreamingPreagg(aggregate.isUsingStream());
aggregationNode.setIntermediateTuple();
break;
- case GLOBAL:
case DISTINCT_LOCAL:
+ case GLOBAL:
+ case DISTINCT_GLOBAL:
+ if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
+ aggregationNode.setIntermediateTuple();
+ }
if (currentFragment.getPlanRoot() instanceof ExchangeNode) {
ExchangeNode exchangeNode = (ExchangeNode)
currentFragment.getPlanRoot();
currentFragment = new
PlanFragment(context.nextFragmentId(), exchangeNode, mergePartition);
@@ -257,7 +239,9 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
inputPlanFragment.setDestination(exchangeNode);
context.addPlanFragment(currentFragment);
}
- currentFragment.updateDataPartition(mergePartition);
+ if (aggregate.getAggPhase() != AggPhase.DISTINCT_LOCAL) {
+ currentFragment.updateDataPartition(mergePartition);
+ }
break;
default:
throw new RuntimeException("Unsupported yet");
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
index b4772dad73..86cf47386e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
@@ -79,18 +79,17 @@ public class ChildOutputPropertyDeriver extends
PlanVisitor<PhysicalProperties,
public PhysicalProperties visitPhysicalAggregate(PhysicalAggregate<?
extends Plan> agg, PlanContext context) {
Preconditions.checkState(childrenOutputProperties.size() == 1);
PhysicalProperties childOutputProperty =
childrenOutputProperties.get(0);
- // TODO: add distinct phase output properties
switch (agg.getAggPhase()) {
case LOCAL:
case GLOBAL:
case DISTINCT_LOCAL:
+ case DISTINCT_GLOBAL:
DistributionSpec childSpec =
childOutputProperty.getDistributionSpec();
if (childSpec instanceof DistributionSpecHash) {
DistributionSpecHash distributionSpecHash =
(DistributionSpecHash) childSpec;
return new
PhysicalProperties(distributionSpecHash.withShuffleType(ShuffleType.BUCKETED));
}
return new
PhysicalProperties(childOutputProperty.getDistributionSpec());
- case DISTINCT_GLOBAL:
default:
throw new RuntimeException("Could not derive output properties
for agg phase: " + agg.getAggPhase());
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
index 96379b6339..8b0fc1c95d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
@@ -24,7 +24,9 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
@@ -36,9 +38,11 @@ import
org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.JoinUtils;
+import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.List;
+import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -94,10 +98,18 @@ public class RequestPropertyDeriver extends
PlanVisitor<Void, PlanContext> {
}
// 2. second phase agg, need to return shuffle with partition key
List<Expression> partitionExpressions = agg.getPartitionExpressions();
- if (partitionExpressions.isEmpty()) {
+ if (partitionExpressions.isEmpty() && agg.getAggPhase() !=
AggPhase.DISTINCT_LOCAL) {
addRequestPropertyToChildren(PhysicalProperties.GATHER);
return null;
}
+ if (agg.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
+ // use slots in distinct agg as shuffle slots
+ List<ExprId> shuffleSlots =
extractFromDistinctFunction(agg.getOutputExpressions());
+ Preconditions.checkState(!shuffleSlots.isEmpty());
+ addRequestPropertyToChildren(
+ PhysicalProperties.createHash(new
DistributionSpecHash(shuffleSlots, ShuffleType.AGGREGATE)));
+ return null;
+ }
// TODO: when parent is a join node,
// use requestPropertyFromParent to keep column order as join to
avoid shuffle again.
if
(partitionExpressions.stream().allMatch(SlotReference.class::isInstance)) {
@@ -170,5 +182,23 @@ public class RequestPropertyDeriver extends
PlanVisitor<Void, PlanContext> {
private void addRequestPropertyToChildren(PhysicalProperties...
physicalProperties) {
requestPropertyToChildren.add(Lists.newArrayList(physicalProperties));
}
+
+ private List<ExprId> extractFromDistinctFunction(List<NamedExpression>
outputExpression) {
+ List<ExprId> exprIds = Lists.newArrayList();
+ for (NamedExpression originOutputExpr : outputExpression) {
+ Set<AggregateFunction> aggregateFunctions
+ =
originOutputExpr.collect(AggregateFunction.class::isInstance);
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ if (aggregateFunction.isDistinct()) {
+ for (Expression expr : aggregateFunction.children()) {
+ Preconditions.checkState(expr instanceof
SlotReference, "normalize aggregate failed to"
+ + " normalize aggregate function " +
aggregateFunction.toSql());
+ exprIds.add(((SlotReference) expr).getExprId());
+ }
+ }
+ }
+ }
+ return exprIds;
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
index 87ed6b1bb1..d3212f9214 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
@@ -41,6 +41,7 @@ import
org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickS
import
org.apache.doris.nereids.rules.implementation.LogicalTVFRelationToPhysicalTVFRelation;
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
+import org.apache.doris.nereids.rules.rewrite.DistinctAggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuterJoin;
import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeLimits;
@@ -73,6 +74,7 @@ public class RuleSet {
.add(SemiJoinSemiJoinTranspose.INSTANCE)
.add(SemiJoinSemiJoinTransposeProject.INSTANCE)
.add(new AggregateDisassemble())
+ .add(new DistinctAggregateDisassemble())
.add(new PushdownFilterThroughProject())
.add(new MergeProjects())
.build();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 32bb60b8b3..ec1698a068 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -77,6 +77,7 @@ public enum RuleType {
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
NORMALIZE_REPEAT(RuleTypeClass.REWRITE),
AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
+ DISTINCT_AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
ELIMINATE_UNNECESSARY_PROJECT(RuleTypeClass.REWRITE),
LOGICAL_SUB_QUERY_ALIAS_TO_LOGICAL_PROJECT(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
index 78d20462da..ac0c1238e0 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
@@ -37,6 +37,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import
org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@@ -177,7 +178,7 @@ public class BindFunction implements AnalysisRuleFactory {
}
if (arguments.size() == 1) {
AggregateParam aggregateParam = new AggregateParam(
- unboundFunction.isDistinct(), true, false);
+ unboundFunction.isDistinct(), true,
AggPhase.LOCAL, false);
return new Count(aggregateParam,
unboundFunction.getArguments().get(0));
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index a09066c267..4d0b55ebe3 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -17,7 +17,6 @@
package org.apache.doris.nereids.rules.rewrite;
-import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
@@ -37,7 +36,6 @@ import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@@ -53,84 +51,23 @@ import java.util.stream.Collectors;
* Aggregate(phase: [GLOBAL], outputExpr: [Alias(b) #1, Alias(SUM(a) + 1)
#2], groupByExpr: [b])
* +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as
b], groupByExpr: [k + 1])
* +-- childPlan
- *
- * Distinct Agg With Group By Processing:
- * If we have a query: SELECT count(distinct v1 * v2) + 1 FROM t GROUP BY k + 1
- * the initial plan is:
- * Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1,
Alias(COUNT(distinct v1 * v2) + 1) #2]
- * , groupByExpr: [k + 1])
- * +-- childPlan
- * we should rewrite to:
- * Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [Alias(b) #1,
Alias(COUNT(distinct a) + 1) #2], groupByExpr: [b])
- * +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a])
- * +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as
a], groupByExpr: [k + 1, a])
- * +-- childPlan
* </pre>
*
* TODO:
- * 1. use different class represent different phase aggregate
- * 2. if instance count is 1, shouldn't disassemble the agg plan
+ * 1. if instance count is 1, shouldn't disassemble the agg plan
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate()
- .whenNot(LogicalAggregate::isDisassembled)
- .then(aggregate -> {
- // used in secondDisassemble to transform local
expressions into global
- Map<Expression, Expression> globalOutputSMap =
Maps.newHashMap();
- Pair<LogicalAggregate<LogicalAggregate<GroupPlan>>,
Boolean> result
- = disassembleAggregateFunction(aggregate,
globalOutputSMap);
- LogicalAggregate<LogicalAggregate<GroupPlan>> newPlan =
result.first;
- boolean hasDistinct = result.second;
- if (!hasDistinct) {
- return newPlan;
- }
- return disassembleDistinct(newPlan, globalOutputSMap);
- }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
+ .when(LogicalAggregate::isFinalPhase)
+ .when(LogicalAggregate::isLocal)
+
.then(this::disassembleAggregateFunction).toRule(RuleType.AGGREGATE_DISASSEMBLE);
}
- // only support distinct function with group by
- // TODO: support distinct function without group by. (add second global
phase)
- private LogicalAggregate<LogicalAggregate<LogicalAggregate<GroupPlan>>>
disassembleDistinct(
- LogicalAggregate<LogicalAggregate<GroupPlan>> aggregate,
- Map<Expression, Expression> globalOutputSMap) {
- LogicalAggregate<GroupPlan> localDistinct = aggregate.child();
- // replace expression in globalOutputExprs and globalGroupByExprs
- List<NamedExpression> globalOutputExprs =
localDistinct.getOutputExpressions().stream()
- .map(e -> ExpressionUtils.replace(e, globalOutputSMap))
- .map(NamedExpression.class::cast)
- .collect(Collectors.toList());
-
- // generate new plan
- LogicalAggregate<LogicalAggregate<GroupPlan>> globalDistinct = new
LogicalAggregate<>(
- localDistinct.getGroupByExpressions(),
- globalOutputExprs,
- Optional.of(aggregate.getGroupByExpressions()),
- true,
- aggregate.isNormalized(),
- false,
- AggPhase.GLOBAL,
- aggregate.getSourceRepeat(),
- localDistinct
- );
- return new LogicalAggregate<>(
- aggregate.getGroupByExpressions(),
- aggregate.getOutputExpressions(),
- true,
- aggregate.isNormalized(),
- true,
- AggPhase.DISTINCT_LOCAL,
- aggregate.getSourceRepeat(),
- globalDistinct
- );
- }
-
- private Pair<LogicalAggregate<LogicalAggregate<GroupPlan>>, Boolean>
disassembleAggregateFunction(
- LogicalAggregate<GroupPlan> aggregate,
- Map<Expression, Expression> globalOutputSMap) {
- boolean hasDistinct = Boolean.FALSE;
+ private LogicalAggregate<LogicalAggregate<GroupPlan>>
disassembleAggregateFunction(
+ LogicalAggregate<GroupPlan> aggregate) {
List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
@@ -162,11 +99,8 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
Preconditions.checkState(originGroupByExpr instanceof
SlotReference,
"normalize aggregate failed to normalize group by
expression " + originGroupByExpr.toSql());
inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
- globalOutputSMap.put(originGroupByExpr, originGroupByExpr);
localOutputExprs.add((SlotReference) originGroupByExpr);
}
- List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList();
- List<NamedExpression> distinctExprsForLocalOutput =
Lists.newArrayList();
for (NamedExpression originOutputExpr : originOutputExprs) {
Set<AggregateFunction> aggregateFunctions
=
originOutputExpr.collect(AggregateFunction.class::isInstance);
@@ -174,39 +108,19 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
if (inputSubstitutionMap.containsKey(aggregateFunction)) {
continue;
}
- if (aggregateFunction.isDistinct()) {
- hasDistinct = Boolean.TRUE;
- for (Expression expr : aggregateFunction.children()) {
- Preconditions.checkState(expr instanceof
SlotReference, "normalize aggregate failed to"
- + " normalize aggregate function " +
aggregateFunction.toSql());
- distinctExprsForLocalOutput.add((SlotReference) expr);
- if (!inputSubstitutionMap.containsKey(expr)) {
- inputSubstitutionMap.put(expr, expr);
- globalOutputSMap.put(expr, expr);
- }
- distinctExprsForLocalGroupBy.add(expr);
- }
- continue;
- }
-
AggregateFunction localAggregateFunction =
aggregateFunction.withAggregateParam(
aggregateFunction.getAggregateParam()
- .withDistinct(false)
- .withGlobalAndDisassembled(false, true)
+ .withPhaseAndDisassembled(false,
AggPhase.LOCAL, true)
);
NamedExpression localOutputExpr = new
Alias(localAggregateFunction, aggregateFunction.toSql());
AggregateFunction substitutionValue = aggregateFunction
// save the origin input types to the global aggregate
functions
.withAggregateParam(aggregateFunction.getAggregateParam()
- .withDistinct(false)
- .withGlobalAndDisassembled(true, true))
+ .withPhaseAndDisassembled(true,
AggPhase.GLOBAL, true))
.withChildren(Lists.newArrayList(localOutputExpr.toSlot()));
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
- // because we use local output exprs to generate global output
in disassembleDistinct,
- // so we must use localAggregateFunction as key
- globalOutputSMap.put(localAggregateFunction,
substitutionValue);
localOutputExprs.add(localOutputExpr);
}
}
@@ -218,10 +132,6 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
.collect(Collectors.toList());
List<Expression> globalGroupByExprs = localGroupByExprs.stream()
.map(e -> ExpressionUtils.replace(e,
inputSubstitutionMap)).collect(Collectors.toList());
- // To avoid repeated substitution of distinct expressions,
- // here the expressions are put into the local after the substitution
is completed
- localOutputExprs.addAll(distinctExprsForLocalOutput);
- localGroupByExprs.addAll(distinctExprsForLocalGroupBy);
// 4. generate new plan
LogicalAggregate<GroupPlan> localAggregate = new LogicalAggregate<>(
localGroupByExprs,
@@ -233,7 +143,7 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
aggregate.getSourceRepeat(),
aggregate.child()
);
- return Pair.of(new LogicalAggregate<>(
+ return new LogicalAggregate<>(
globalGroupByExprs,
globalOutputExprs,
true,
@@ -242,6 +152,6 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
AggPhase.GLOBAL,
aggregate.getSourceRepeat(),
localAggregate
- ), hasDistinct);
+ );
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java
new file mode 100644
index 0000000000..1d182cf346
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateDisassemble.java
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.plans.AggPhase;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Used to generate the merge agg node for distributed execution.
+ * NOTICE: DISTINCT GLOBAL output expressions' ExprId should SAME with ORIGIN
output expressions' ExprId.
+ * <pre>
+ * If we have a query: SELECT COUNT(distinct v1 * v2) + 1 FROM t
+ * the initial plan is:
+ * +-- Aggregate(phase: [LOCAL], outputExpr: [Alias(COUNT(distinct v1 * v2)
+ 1) #2])
+ * +-- childPlan
+ * we should rewrite to:
+ * Aggregate(phase: [GLOBAL DISTINCT], outputExpr: [Alias(SUM(c) + 1) #2])
+ * +-- Aggregate(phase: [LOCAL DISTINCT], outputExpr: [SUM(b) as c] )
+ * +-- Aggregate(phase: [GLOBAL], outputExpr: [COUNT(distinct a) as b])
+ * +-- Aggregate(phase: [LOCAL], outputExpr: [COUNT(distinct v1 *
v2) as a])
+ * +-- childPlan
+ * </pre>
+ */
+public class DistinctAggregateDisassemble extends OneRewriteRuleFactory {
+
+ @Override
+ public Rule build() {
+ return logicalAggregate()
+ .when(LogicalAggregate::needDistinctDisassemble)
+
.then(this::disassembleAggregateFunction).toRule(RuleType.DISTINCT_AGGREGATE_DISASSEMBLE);
+ }
+
+ private
LogicalAggregate<LogicalAggregate<LogicalAggregate<LogicalAggregate<GroupPlan>>>>
+ disassembleAggregateFunction(
+ LogicalAggregate<GroupPlan> aggregate) {
+ // Double-check to prevent incorrect changes
+ Preconditions.checkArgument(aggregate.getAggPhase() == AggPhase.LOCAL);
+ Preconditions.checkArgument(aggregate.isFinalPhase());
+ List<Expression> groupByExpressions =
aggregate.getGroupByExpressions();
+ if (groupByExpressions == null || groupByExpressions.isEmpty()) {
+ // If there are no group by expressions, in order to parallelize,
+ // we need to manually use the distinct function argument as group
by expressions
+ groupByExpressions = new
ArrayList<>(getDistinctFunctionParams(aggregate));
+ }
+ Pair<List<NamedExpression>, List<NamedExpression>> localAndGlobal =
+ disassemble(aggregate.getOutputExpressions(),
+ groupByExpressions,
+ AggPhase.LOCAL, AggPhase.GLOBAL);
+ Pair<List<NamedExpression>, List<NamedExpression>>
globalAndDistinctLocal =
+ disassemble(localAndGlobal.second,
+ groupByExpressions,
+ AggPhase.GLOBAL, AggPhase.DISTINCT_LOCAL);
+ Pair<List<NamedExpression>, List<NamedExpression>>
distinctLocalAndDistinctGlobal =
+ disassemble(globalAndDistinctLocal.second,
+ aggregate.getGroupByExpressions(),
+ AggPhase.DISTINCT_LOCAL, AggPhase.DISTINCT_GLOBAL);
+ // generate new plan
+ LogicalAggregate<GroupPlan> localAggregate = new LogicalAggregate<>(
+ groupByExpressions,
+ localAndGlobal.first,
+ true,
+ aggregate.isNormalized(),
+ false,
+ AggPhase.LOCAL,
+ aggregate.getSourceRepeat(),
+ aggregate.child()
+ );
+ LogicalAggregate<LogicalAggregate<GroupPlan>> globalAggregate = new
LogicalAggregate<>(
+ groupByExpressions,
+ globalAndDistinctLocal.first,
+ true,
+ aggregate.isNormalized(),
+ false,
+ AggPhase.GLOBAL,
+ aggregate.getSourceRepeat(),
+ localAggregate
+ );
+ LogicalAggregate<LogicalAggregate<LogicalAggregate<GroupPlan>>>
distinctLocalAggregate =
+ new LogicalAggregate<>(
+ aggregate.getGroupByExpressions(),
+ distinctLocalAndDistinctGlobal.first,
+ true,
+ aggregate.isNormalized(),
+ false,
+ AggPhase.DISTINCT_LOCAL,
+ aggregate.getSourceRepeat(),
+ globalAggregate
+ );
+ return new LogicalAggregate<>(
+ aggregate.getGroupByExpressions(),
+ distinctLocalAndDistinctGlobal.second,
+ true,
+ aggregate.isNormalized(),
+ true,
+ AggPhase.DISTINCT_GLOBAL,
+ aggregate.getSourceRepeat(),
+ distinctLocalAggregate
+ );
+ }
+
+ private Pair<List<NamedExpression>, List<NamedExpression>> disassemble(
+ List<NamedExpression> originOutputExprs,
+ List<Expression> childGroupByExprs,
+ AggPhase childPhase,
+ AggPhase parentPhase) {
+ Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
+
+ List<NamedExpression> childOutputExprs = Lists.newArrayList();
+ // The groupBy slots are placed at the beginning of the output, in
line with the stale optimiser
+ childGroupByExprs.stream().forEach(expression ->
childOutputExprs.add((SlotReference) expression));
+ for (NamedExpression originOutputExpr : originOutputExprs) {
+ Set<AggregateFunction> aggregateFunctions
+ =
originOutputExpr.collect(AggregateFunction.class::isInstance);
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ if (inputSubstitutionMap.containsKey(aggregateFunction)) {
+ continue;
+ }
+ AggregateFunction childAggregateFunction =
aggregateFunction.withAggregateParam(
+ aggregateFunction.getAggregateParam()
+ .withPhaseAndDisassembled(false, childPhase,
true)
+ );
+ NamedExpression childOutputExpr = new
Alias(childAggregateFunction, aggregateFunction.toSql());
+ AggregateFunction substitutionValue = aggregateFunction
+ // save the origin input types to the global aggregate
functions
+
.withAggregateParam(aggregateFunction.getAggregateParam()
+ .withPhaseAndDisassembled(true, parentPhase,
true))
+
.withChildren(Lists.newArrayList(childOutputExpr.toSlot()));
+
+ inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+ childOutputExprs.add(childOutputExpr);
+ }
+ }
+
+ // 3. replace expression in parentOutputExprs
+ List<NamedExpression> parentOutputExprs = originOutputExprs.stream()
+ .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+ return Pair.of(childOutputExprs, parentOutputExprs);
+ }
+
+ private List<Expression>
getDistinctFunctionParams(LogicalAggregate<GroupPlan> agg) {
+ List<Expression> result = new ArrayList<>();
+ for (NamedExpression originOutputExpr : agg.getOutputExpressions()) {
+ Set<AggregateFunction> aggregateFunctions
+ =
originOutputExpr.collect(AggregateFunction.class::isInstance);
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ if (aggregateFunction.isDistinct()) {
+ result.addAll(aggregateFunction.children());
+ }
+ }
+ }
+ return result;
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
index 80c799391b..e97071f0a7 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
@@ -38,7 +38,7 @@ public abstract class AggregateFunction extends BoundFunction
implements Expects
private final AggregateParam aggregateParam;
+    /**
+     * Create an aggregate function with the default parameter {@code AggregateParam.finalPhase()},
+     * i.e. non-distinct, final phase, phase {@code AggPhase.LOCAL}, not disassembled.
+     */
public AggregateFunction(String name, Expression... arguments) {
- this(name, AggregateParam.global(), arguments);
+ this(name, AggregateParam.finalPhase(), arguments);
}
public AggregateFunction(String name, AggregateParam aggregateParam,
Expression... arguments) {
@@ -79,7 +79,7 @@ public abstract class AggregateFunction extends BoundFunction
implements Expects
@Override
public final DataType getDataType() {
- if (aggregateParam.isGlobal) {
+ if (aggregateParam.aggPhase.isGlobal() || aggregateParam.isFinalPhase)
{
return getFinalType();
} else {
return getIntermediateTypes();
@@ -114,7 +114,11 @@ public abstract class AggregateFunction extends
BoundFunction implements Expects
}
+    /**
+     * Whether this function belongs to a GLOBAL aggregate phase ({@code aggPhase.isGlobal()}).
+     * This used to mean "final result"; that meaning now lives in {@link #isFinalPhase()}.
+     */
public boolean isGlobal() {
- return aggregateParam.isGlobal;
+ return aggregateParam.aggPhase.isGlobal();
+ }
+
+    /**
+     * Whether this function's parameter is marked as the final phase; when true,
+     * {@code getDataType()} reports the final type instead of the intermediate types.
+     */
+ public boolean isFinalPhase() {
+ return aggregateParam.isFinalPhase;
}
public boolean isDisassembled() {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
index 9e4b8e1394..c96e437768 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
@@ -17,51 +17,57 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
+import org.apache.doris.nereids.trees.plans.AggPhase;
+
import com.google.common.base.Preconditions;
import java.util.Objects;
/** AggregateParam. */
public class AggregateParam {
- public final boolean isGlobal;
+    /** Whether this phase produces the function's final (non-intermediate) result. */
+    public final boolean isFinalPhase;
+
+    /** The aggregate phase (e.g. LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL) of the function. */
+    public final AggPhase aggPhase;
public final boolean isDistinct;
public final boolean isDisassembled;
- /** AggregateParam */
+    /**
+     * Create an AggregateParam.
+     *
+     * @param isDistinct whether the aggregate function is DISTINCT
+     * @param isFinalPhase whether this phase produces the final result
+     * @param aggPhase the aggregate phase the function belongs to
+     * @param isDisassembled whether the aggregate has been split into phases
+     * @throws IllegalArgumentException if a non-final phase is not marked as disassembled
+     */
- public AggregateParam(boolean isDistinct, boolean isGlobal, boolean
isDisassembled) {
+    public AggregateParam(boolean isDistinct, boolean isFinalPhase, AggPhase aggPhase, boolean isDisassembled) {
+        this.isFinalPhase = isFinalPhase;
this.isDistinct = isDistinct;
- this.isGlobal = isGlobal;
+        this.aggPhase = aggPhase;
this.isDisassembled = isDisassembled;
- if (!isGlobal) {
- Preconditions.checkArgument(isDisassembled == true,
- "local aggregate should be disassembed");
+        if (!isFinalPhase) {
+            // an intermediate (non-final) phase only exists after the aggregate has been split
+            Preconditions.checkArgument(isDisassembled,
+                    "non-final phase aggregate should be disassembled");
}
}
- public static AggregateParam global() {
- return new AggregateParam(false, true, false);
+    /** Param of a non-distinct, not yet disassembled aggregate in the final phase (recorded as LOCAL). */
+    public static AggregateParam finalPhase() {
+        return new AggregateParam(false, true, AggPhase.LOCAL, false);
}
- public static AggregateParam distinctAndGlobal() {
- return new AggregateParam(true, true, false);
+    /** Same as {@link #finalPhase()} but for a DISTINCT aggregate. */
+    public static AggregateParam distinctAndFinalPhase() {
+        return finalPhase().withDistinct(true);
}
+    /** Copy of this param with the distinct flag replaced. */
public AggregateParam withDistinct(boolean isDistinct) {
- return new AggregateParam(isDistinct, isGlobal, isDisassembled);
+        return new AggregateParam(isDistinct, isFinalPhase, aggPhase, isDisassembled);
}
- public AggregateParam withGlobal(boolean isGlobal) {
- return new AggregateParam(isDistinct, isGlobal, isDisassembled);
+    /** Copy of this param moved to another aggregate phase. */
+    public AggregateParam withAggPhase(AggPhase aggPhase) {
+        return withPhaseAndDisassembled(isFinalPhase, aggPhase, isDisassembled);
}
+    /** Copy of this param with the disassembled flag replaced. */
public AggregateParam withDisassembled(boolean isDisassembled) {
- return new AggregateParam(isDistinct, isGlobal, isDisassembled);
+        return withPhaseAndDisassembled(isFinalPhase, aggPhase, isDisassembled);
}
- public AggregateParam withGlobalAndDisassembled(boolean isGlobal, boolean
isDisassembled) {
- return new AggregateParam(isDistinct, isGlobal, isDisassembled);
+    /** Copy of this param with the final-phase flag, phase and disassembled flag all replaced. */
+    public AggregateParam withPhaseAndDisassembled(boolean isFinalPhase, AggPhase aggPhase,
+            boolean isDisassembled) {
+        return new AggregateParam(isDistinct, isFinalPhase, aggPhase, isDisassembled);
}
@Override
@@ -74,12 +80,13 @@ public class AggregateParam {
}
AggregateParam that = (AggregateParam) o;
return isDistinct == that.isDistinct
- && Objects.equals(isGlobal, that.isGlobal)
+ && isFinalPhase == that.isFinalPhase
+ && Objects.equals(aggPhase, that.aggPhase)
&& Objects.equals(isDisassembled, that.isDisassembled);
}
@Override
public int hashCode() {
+        // hash exactly the four fields equals() compares, keeping the equals/hashCode contract
- return Objects.hash(isDistinct, isGlobal, isDisassembled);
+ return Objects.hash(isDistinct, isFinalPhase, aggPhase,
isDisassembled);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index db5d703c4f..e94d32dfba 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -22,6 +22,7 @@ import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
@@ -35,6 +36,7 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
+import java.util.Set;
/**
* Logical Aggregate plan.
@@ -204,10 +206,41 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
.build();
}
+    /** Whether this aggregate's phase is local, as reported by {@code AggPhase#isLocal()}. */
+ public boolean isLocal() {
+ return aggPhase.isLocal();
+ }
+
+    /** Returns the {@code disassembled} flag of this aggregate. */
public boolean isDisassembled() {
return disassembled;
}
+    /**
+     * Check whether the four-phase distinct disassembly
+     * (Local -> Global -> Distinct Local -> Distinct Global) may be applied to this aggregate.
+     *
+     * <p>Only an aggregate that has not been split yet (final phase with {@link AggPhase#LOCAL})
+     * and that contains exactly one DISTINCT aggregate function qualifies. The group-by list is
+     * not inspected here.
+     *
+     * <p>NOTE: distinct functions are counted per output expression, so the same DISTINCT
+     * function referenced from two output expressions counts twice.
+     *
+     * @return true if the aggregate is eligible for the four-phase distinct disassembly
+     */
+    public boolean needDistinctDisassemble() {
+        // an aggregate that is already split (non-final, or not in the LOCAL phase) is not split again
+        if (!isFinalPhase || aggPhase != AggPhase.LOCAL) {
+            return false;
+        }
+        int distinctFunctionCount = 0;
+        for (NamedExpression originOutputExpr : outputExpressions) {
+            Set<AggregateFunction> aggregateFunctions
+                    = originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (aggregateFunction.isDistinct()) {
+                    distinctFunctionCount++;
+                    if (distinctFunctionCount > 1) {
+                        // more than one distinct function is not supported; stop early
+                        return false;
+                    }
+                }
+            }
+        }
+        // Only one distinct function is supported
+        return distinctFunctionCount == 1;
+    }
+
+    /** Returns the {@code normalized} flag of this aggregate. */
public boolean isNormalized() {
return normalized;
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index f10b6325a2..9d4d5711e2 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
+import org.apache.doris.nereids.rules.rewrite.DistinctAggregateDisassemble;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
@@ -44,6 +45,7 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
+import java.util.ArrayList;
import java.util.List;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -206,121 +208,147 @@ public class AggregateDisassembleTest implements
PatternMatchSupported {
/**
* <pre>
* the initial plan is:
- * Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as
c], groupByExpr: [id])
+ * Aggregate(phase: [LOCAL], outputExpr: [(COUNT(distinct age) + 2) as
c], groupByExpr: [])
* +-- childPlan(id, name, age)
* we should rewrite to:
- * Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct age)
+ 2) as c], groupByExpr: [id])
- * +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr:
[id, age])
- * +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr:
[id, age])
- * +-- childPlan(id, name, age)
+ * Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [])
+ * +-- Aggregate(phase: [LOCAL], outputExpr: [COUNT(distinct age)], groupByExpr: [])
+ *     +-- childPlan(id, name, age)
+ * i.e. the LOCAL phase computes the partial (update) count and the GLOBAL phase merges it
+ * and evaluates the original output expression.
* </pre>
*/
@Test
- public void distinctAggregateWithGroupBy() {
- List<Expression> groupExpressionList =
Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
+ public void distinctAggregateWithoutGroupByApply2PhaseRule() {
+ List<Expression> groupExpressionList = new ArrayList<>();
List<NamedExpression> outputExpressionList = Lists.newArrayList(new
Alias(
- new Add(new Count(AggregateParam.distinctAndGlobal(),
rStudent.getOutput().get(2).toSlot()),
+ new Add(new Count(AggregateParam.distinctAndFinalPhase(),
rStudent.getOutput().get(2).toSlot()),
new IntegerLiteral(2)), "c"));
Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new AggregateDisassemble())
+ .matchesFromRoot(
+ logicalAggregate(
+ logicalAggregate()
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
+ .when(agg ->
agg.getOutputExpressions().size() == 1)
+ .when(agg ->
agg.getGroupByExpressions().isEmpty())
+ ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
+ .when(agg -> agg.getOutputExpressions().size() ==
1)
+ .when(agg -> agg.getGroupByExpressions().isEmpty())
+ );
+ }
+
+    /**
+     * <pre>
+     * count(distinct age) as c, sum(id) as sum, grouped by id, rewritten by the
+     * two-phase {@code AggregateDisassemble} rule into:
+     * Aggregate(phase: [GLOBAL], outputExpr: [count(...) as c, sum(...) as sum], groupByExpr: [id])
+     * +-- Aggregate(phase: [LOCAL], outputExpr: [id, count(distinct age), sum(id)], groupByExpr: [id])
+     *     +-- childPlan(id, name, age)
+     * </pre>
+     */
+ @Test
+ public void distinctWithNormalAggregateFunctionApply2PhaseRule() {
+ List<Expression> groupExpressionList =
Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
+ List<NamedExpression> outputExpressionList = Lists.newArrayList(
+ new Alias(new Count(AggregateParam.distinctAndFinalPhase(),
rStudent.getOutput().get(2).toSlot()), "c"),
+ new Alias(new Sum(rStudent.getOutput().get(0).toSlot()),
"sum"));
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
+
// check local:
// id
Expression localOutput0 = rStudent.getOutput().get(0);
- // age
- Expression localOutput1 = rStudent.getOutput().get(2);
+ // count
+ Count localOutput1 = new Count(new AggregateParam(true, false,
AggPhase.LOCAL, true), rStudent.getOutput().get(2).toSlot());
+ // sum
+ Sum localOutput2 = new Sum(new AggregateParam(false, false,
AggPhase.LOCAL, true), rStudent.getOutput().get(0).toSlot());
// id
Expression localGroupBy0 = rStudent.getOutput().get(0);
- // age
- Expression localGroupBy1 = rStudent.getOutput().get(2);
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new AggregateDisassemble())
.matchesFromRoot(
- logicalAggregate(
- logicalAggregate(
- logicalAggregate()
- .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
- .when(agg ->
agg.getOutputExpressions().get(0).equals(localOutput0))
- .when(agg ->
agg.getOutputExpressions().get(1).equals(localOutput1))
- .when(agg ->
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
- .when(agg ->
agg.getGroupByExpressions().get(1).equals(localGroupBy1))
- ).when(agg ->
agg.getAggPhase().equals(AggPhase.GLOBAL))
- .when(agg ->
agg.getOutputExpressions().get(0)
-
.equals(agg.child().getOutputExpressions().get(0)))
- .when(agg ->
agg.getOutputExpressions().get(1)
-
.equals(agg.child().getOutputExpressions().get(1)))
- .when(agg ->
agg.getGroupByExpressions().get(0)
-
.equals(agg.child().getOutputExpressions().get(0)))
- .when(agg ->
agg.getGroupByExpressions().get(1)
-
.equals(agg.child().getOutputExpressions().get(1)))
- ).when(agg ->
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
- .when(agg -> agg.getOutputExpressions().size()
== 1)
- .when(agg -> agg.getOutputExpressions().get(0)
instanceof Alias)
- .when(agg ->
agg.getOutputExpressions().get(0).child(0) instanceof Add)
- .when(agg -> agg.getGroupByExpressions().get(0)
-
.equals(agg.child().child().getOutputExpressions().get(0)))
- .when(agg ->
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
- 0).getExprId())
+ logicalAggregate(
+ logicalAggregate()
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
+ .when(agg ->
agg.getOutputExpressions().get(0).equals(localOutput0))
+ .when(agg ->
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+ .when(agg ->
agg.getOutputExpressions().get(2).child(0).equals(localOutput2))
+ .when(agg ->
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
+ ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
+ .when(agg -> {
+ Slot child =
agg.child().getOutputExpressions().get(1).toSlot();
+
Assertions.assertTrue(agg.getOutputExpressions().get(0).child(0) instanceof
Count);
+ return
agg.getOutputExpressions().get(0).child(0).child(0).equals(child);
+ })
+ .when(agg -> {
+ Slot child =
agg.child().getOutputExpressions().get(2).toSlot();
+
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof
Sum);
+ return ((Sum)
agg.getOutputExpressions().get(1).child(0)).child().equals(child);
+ })
+ .when(agg -> agg.getGroupByExpressions().get(0)
+
.equals(agg.child().getOutputExpressions().get(0)))
);
}
+    /**
+     * <pre>
+     * count(distinct age) as c, sum(id) as sum, grouped by id, rewritten by the
+     * four-phase {@code DistinctAggregateDisassemble} rule into:
+     * Aggregate(phase: [DISTINCT_GLOBAL], outputExpr: [c, sum], groupByExpr: [id])
+     * +-- Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [id, count(...), sum(...)], groupByExpr: [id])
+     *     +-- Aggregate(phase: [GLOBAL], outputExpr: [id, count(...), sum(...)], groupByExpr: [id])
+     *         +-- Aggregate(phase: [LOCAL], outputExpr: [id, count(distinct age), sum(id)], groupByExpr: [id])
+     *             +-- childPlan(id, name, age)
+     * The root keeps the ExprIds of the original output aliases.
+     * </pre>
+     */
@Test
- public void distinctWithNormalAggregateFunction() {
+ public void distinctWithNormalAggregateFunctionApply4PhaseRule() {
List<Expression> groupExpressionList =
Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
List<NamedExpression> outputExpressionList = Lists.newArrayList(
- new Alias(new Count(AggregateParam.distinctAndGlobal(),
rStudent.getOutput().get(2).toSlot()), "c"),
+ new Alias(new Count(AggregateParam.distinctAndFinalPhase(),
rStudent.getOutput().get(2).toSlot()), "c"),
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()),
"sum"));
Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
// check local:
// id
Expression localOutput0 = rStudent.getOutput().get(0);
+ // count
+ Count localOutput1 = new Count(new AggregateParam(true, false,
AggPhase.LOCAL, true), rStudent.getOutput().get(2).toSlot());
// sum
- Sum localOutput1 = new Sum(new AggregateParam(false, false, true),
rStudent.getOutput().get(0).toSlot());
- // age
- Expression localOutput2 = rStudent.getOutput().get(2);
+ Sum localOutput2 = new Sum(new AggregateParam(false, false,
AggPhase.LOCAL, true), rStudent.getOutput().get(0).toSlot());
// id
Expression localGroupBy0 = rStudent.getOutput().get(0);
- // age
- Expression localGroupBy1 = rStudent.getOutput().get(2);
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
- .applyTopDown(new AggregateDisassemble())
+ .applyTopDown(new DistinctAggregateDisassemble())
.matchesFromRoot(
+ logicalAggregate(
+ logicalAggregate(
logicalAggregate(
- logicalAggregate(
- logicalAggregate()
- .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
- .when(agg ->
agg.getOutputExpressions().get(0).equals(localOutput0))
- .when(agg ->
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
- .when(agg ->
agg.getOutputExpressions().get(2).equals(localOutput2))
- .when(agg ->
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
- .when(agg ->
agg.getGroupByExpressions().get(1).equals(localGroupBy1))
- ).when(agg ->
agg.getAggPhase().equals(AggPhase.GLOBAL))
- .when(agg ->
agg.getOutputExpressions().get(0)
-
.equals(agg.child().getOutputExpressions().get(0)))
- .when(agg -> {
- Slot child =
agg.child().getOutputExpressions().get(1).toSlot();
-
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof
Sum);
- return ((Sum)
agg.getOutputExpressions().get(1).child(0)).child().equals(child);
- })
- .when(agg ->
agg.getOutputExpressions().get(2)
-
.equals(agg.child().getOutputExpressions().get(2)))
- .when(agg ->
agg.getGroupByExpressions().get(0)
-
.equals(agg.child().getOutputExpressions().get(0)))
- .when(agg ->
agg.getGroupByExpressions().get(1)
-
.equals(agg.child().getOutputExpressions().get(2)))
- ).when(agg ->
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
- .when(agg -> agg.getOutputExpressions().size()
== 2)
- .when(agg -> agg.getOutputExpressions().get(0)
instanceof Alias)
- .when(agg ->
agg.getOutputExpressions().get(0).child(0) instanceof Count)
- .when(agg ->
agg.getOutputExpressions().get(1).child(0) instanceof Sum)
- .when(agg ->
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
- 0).getExprId())
- .when(agg ->
agg.getOutputExpressions().get(1).getExprId() == outputExpressionList.get(
- 1).getExprId())
+ logicalAggregate()
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
+ .when(agg ->
agg.getOutputExpressions().get(0).equals(localOutput0))
+ .when(agg ->
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+ .when(agg ->
agg.getOutputExpressions().get(2).child(0).equals(localOutput2))
+ .when(agg ->
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
+ ).when(agg ->
agg.getAggPhase().equals(AggPhase.GLOBAL))
+ .when(agg -> {
+ Slot child =
agg.child().getOutputExpressions().get(1).toSlot();
+
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof
Count);
+ return
agg.getOutputExpressions().get(1).child(0).child(0).equals(child);
+ })
+ .when(agg -> {
+ Slot child =
agg.child().getOutputExpressions().get(2).toSlot();
+
Assertions.assertTrue(agg.getOutputExpressions().get(2).child(0) instanceof
Sum);
+ return ((Sum)
agg.getOutputExpressions().get(2).child(0)).child().equals(child);
+ })
.when(agg -> agg.getGroupByExpressions().get(0)
-
.equals(agg.child().child().getOutputExpressions().get(0)))
+
.equals(agg.child().getOutputExpressions().get(0)))
+ ).when(agg ->
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
+ .when(agg -> {
+ Slot child =
agg.child().getOutputExpressions().get(1).toSlot();
+
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof
Count);
+ return
agg.getOutputExpressions().get(1).child(0).child(0).equals(child);
+ })
+ .when(agg -> {
+ Slot child =
agg.child().getOutputExpressions().get(2).toSlot();
+
Assertions.assertTrue(agg.getOutputExpressions().get(2).child(0) instanceof
Sum);
+ return ((Sum)
agg.getOutputExpressions().get(2).child(0)).child().equals(child);
+ })
+ .when(agg -> agg.getGroupByExpressions().get(0)
+
.equals(agg.child().getOutputExpressions().get(0)))
+ ).when(agg ->
agg.getAggPhase().equals(AggPhase.DISTINCT_GLOBAL))
+ .when(agg -> agg.getOutputExpressions().size() == 2)
+ .when(agg -> agg.getOutputExpressions().get(0)
instanceof Alias)
+ .when(agg ->
agg.getOutputExpressions().get(0).child(0) instanceof Count)
+ .when(agg ->
agg.getOutputExpressions().get(1).child(0) instanceof Sum)
+ .when(agg ->
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
+ 0).getExprId())
+ .when(agg ->
agg.getOutputExpressions().get(1).getExprId() == outputExpressionList.get(
+ 1).getExprId())
+ .when(agg -> agg.getGroupByExpressions().get(0)
+
.equals(agg.child().child().child().getOutputExpressions().get(0)))
);
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
index dd8ea662d0..c3bd3830cc 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
@@ -177,13 +177,13 @@ public class ExpressionEqualsTest {
Assertions.assertEquals(count1, count2);
Assertions.assertEquals(count1.hashCode(), count2.hashCode());
- Count count3 = new Count(AggregateParam.distinctAndGlobal(), child1);
- Count count4 = new Count(AggregateParam.distinctAndGlobal(), child2);
+ Count count3 = new Count(AggregateParam.distinctAndFinalPhase(),
child1);
+ Count count4 = new Count(AggregateParam.distinctAndFinalPhase(),
child2);
Assertions.assertEquals(count3, count4);
Assertions.assertEquals(count3.hashCode(), count4.hashCode());
// bad case
- Count count5 = new Count(AggregateParam.distinctAndGlobal(), child1);
+ Count count5 = new Count(AggregateParam.distinctAndFinalPhase(),
child1);
Count count6 = new Count(child2);
Assertions.assertNotEquals(count5, count6);
Assertions.assertNotEquals(count5.hashCode(), count6.hashCode());
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]