This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 5b6d48ed5b [feature](nereids) support distinct count (#12159)
5b6d48ed5b is described below
commit 5b6d48ed5b6db033607224523579da0a77d957f2
Author: yinzhijian <[email protected]>
AuthorDate: Thu Sep 15 13:01:47 2022 +0800
[feature](nereids) support distinct count (#12159)
Support distinct count with a GROUP BY clause.
For example:
SELECT count(distinct c_custkey + 1) FROM customer group by c_nation;
TODO: support distinct count without a GROUP BY clause.
---
.../glue/translator/ExpressionTranslator.java | 2 +
.../glue/translator/PhysicalPlanTranslator.java | 17 +-
.../properties/ChildOutputPropertyDeriver.java | 2 +-
.../nereids/properties/RequestPropertyDeriver.java | 7 +-
.../doris/nereids/rules/analysis/BindFunction.java | 12 ++
.../expression/rewrite/ExpressionRewrite.java | 2 +-
.../LogicalAggToPhysicalHashAgg.java | 1 +
.../rules/rewrite/AggregateDisassemble.java | 235 +++++++++++++++------
.../rules/rewrite/logical/NormalizeAggregate.java | 2 +-
.../expressions/functions/AggregateFunction.java | 31 +++
.../nereids/trees/expressions/functions/Count.java | 12 +-
.../trees/plans/logical/LogicalAggregate.java | 36 +++-
.../trees/plans/physical/PhysicalAggregate.java | 35 ++-
.../doris/nereids/parser/HavingClauseTest.java | 4 +-
.../properties/ChildOutputPropertyDeriverTest.java | 2 +
.../properties/RequestPropertyDeriverTest.java | 3 +
.../rewrite/logical/AggregateDisassembleTest.java | 81 +++++++
.../trees/expressions/ExpressionEqualsTest.java | 20 ++
.../doris/nereids/trees/plans/PlanEqualsTest.java | 12 +-
.../data/nereids_syntax_p0/function.out | 5 +
.../suites/nereids_syntax_p0/function.groovy | 4 +
21 files changed, 416 insertions(+), 109 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index 017ec6b5b7..1c3f59361d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -256,6 +256,8 @@ public class ExpressionTranslator extends
DefaultExpressionVisitor<Expr, PlanTra
Count count = (Count) function;
if (count.isStar()) {
return new FunctionCallExpr(function.getName(),
FunctionParams.createStarParam());
+ } else if (count.isDistinct()) {
+ return new FunctionCallExpr(function.getName(), new
FunctionParams(true, paramList));
}
}
return new FunctionCallExpr(function.getName(), paramList);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index d47bf1c1aa..a783567a70 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -191,12 +191,17 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
// 3. generate output tuple
List<Slot> slotList = Lists.newArrayList();
TupleDescriptor outputTupleDesc;
- if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
+ if (aggregate.getAggPhase() == AggPhase.LOCAL) {
+ outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null,
context);
+ } else if ((aggregate.getAggPhase() == AggPhase.GLOBAL &&
aggregate.isFinalPhase())
+ || aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
slotList.addAll(groupSlotList);
slotList.addAll(aggFunctionOutput);
outputTupleDesc = generateTupleDesc(slotList, null, context);
} else {
- outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null,
context);
+ // In the distinct agg scenario, global shares local's desc
+ AggregationNode localAggNode = (AggregationNode)
inputPlanFragment.getPlanRoot().getChild(0);
+ outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc();
}
if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
@@ -204,6 +209,13 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
execAggregateFunction.setMergeForNereids(true);
}
}
+ if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
+ for (FunctionCallExpr execAggregateFunction :
execAggregateFunctions) {
+ if (!execAggregateFunction.isDistinct()) {
+ execAggregateFunction.setMergeForNereids(true);
+ }
+ }
+ }
AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions,
execAggregateFunctions, outputTupleDesc,
outputTupleDesc, aggregate.getAggPhase().toExec());
AggregationNode aggregationNode = new
AggregationNode(context.nextPlanNodeId(),
@@ -216,6 +228,7 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
aggregationNode.setIntermediateTuple();
break;
case GLOBAL:
+ case DISTINCT_LOCAL:
if (currentFragment.getPlanRoot() instanceof ExchangeNode) {
ExchangeNode exchangeNode = (ExchangeNode)
currentFragment.getPlanRoot();
currentFragment = new
PlanFragment(context.nextFragmentId(), exchangeNode, mergePartition);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
index 1d7974e161..ba8976e71d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
@@ -80,12 +80,12 @@ public class ChildOutputPropertyDeriver extends
PlanVisitor<PhysicalProperties,
case LOCAL:
return new
PhysicalProperties(childOutputProperty.getDistributionSpec());
case GLOBAL:
+ case DISTINCT_LOCAL:
List<ExprId> columns = agg.getPartitionExpressions().stream()
.map(SlotReference.class::cast)
.map(SlotReference::getExprId)
.collect(Collectors.toList());
return PhysicalProperties.createHash(new
DistributionSpecHash(columns, ShuffleType.AGGREGATE));
- case DISTINCT_LOCAL:
case DISTINCT_GLOBAL:
default:
throw new RuntimeException("Could not derive output properties
for agg phase: " + agg.getAggPhase());
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
index 0c2f2089ae..67a9032f85 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
@@ -25,6 +25,7 @@ import
org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
@@ -82,14 +83,16 @@ public class RequestPropertyDeriver extends
PlanVisitor<Void, PlanContext> {
addToRequestPropertyToChildren(PhysicalProperties.ANY);
return null;
}
-
+ if (agg.getAggPhase() == AggPhase.GLOBAL && !agg.isFinalPhase()) {
+ addToRequestPropertyToChildren(requestPropertyFromParent);
+ return null;
+ }
// 2. second phase agg, need to return shuffle with partition key
List<Expression> partitionExpressions = agg.getPartitionExpressions();
if (partitionExpressions.isEmpty()) {
addToRequestPropertyToChildren(PhysicalProperties.GATHER);
return null;
}
-
// TODO: when parent is a join node,
// use requestPropertyFromParent to keep column order as join to
avoid shuffle again.
if
(partitionExpressions.stream().allMatch(SlotReference.class::isInstance)) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
index 40782b2e28..fcef341f5f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
@@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
+import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
@@ -115,6 +116,17 @@ public class BindFunction implements AnalysisRuleFactory {
@Override
public BoundFunction visitUnboundFunction(UnboundFunction
unboundFunction, Env env) {
+ // FunctionRegistry can't support boolean arg now, tricky here.
+ if (unboundFunction.getName().equalsIgnoreCase("count")) {
+ List<Expression> arguments = unboundFunction.getArguments();
+ if ((arguments.size() == 0 && unboundFunction.isStar()) ||
arguments.stream()
+ .allMatch(Expression::isConstant)) {
+ return new Count();
+ }
+ if (arguments.size() == 1) {
+ return new Count(unboundFunction.getArguments().get(0),
unboundFunction.isDistinct());
+ }
+ }
FunctionRegistry functionRegistry = env.getFunctionRegistry();
String functionName = unboundFunction.getName();
FunctionBuilder builder = functionRegistry.findFunctionBuilder(
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
index d808183c24..f660ee0ec0 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
@@ -126,7 +126,7 @@ public class ExpressionRewrite implements
RewriteRuleFactory {
return agg;
}
return new LogicalAggregate<>(newGroupByExprs,
newOutputExpressions,
- agg.isDisassembled(), agg.isNormalized(),
agg.getAggPhase(), agg.child());
+ agg.isDisassembled(), agg.isNormalized(),
agg.isFinalPhase(), agg.getAggPhase(), agg.child());
}).toRule(RuleType.REWRITE_AGG_EXPRESSION);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
index ecc59393d4..4e4d52b551 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
@@ -36,6 +36,7 @@ public class LogicalAggToPhysicalHashAgg extends
OneImplementationRuleFactory {
ImmutableList.of(),
agg.getAggPhase(),
false,
+ agg.isFinalPhase(),
agg.getLogicalProperties(),
agg.child())
).toRule(RuleType.LOGICAL_AGG_TO_PHYSICAL_HASH_AGG_RULE);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index 7d68752e07..4166d9db5d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -49,92 +49,187 @@ import java.util.stream.Collectors;
* +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as
b], groupByExpr: [k + 1])
* +-- childPlan
*
+ * Distinct Agg With Group By Processing:
+ * If we have a query: SELECT count(distinct v1 * v2) + 1 FROM t GROUP BY k + 1
+ * the initial plan is:
+ * Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1,
Alias(COUNT(distinct v1 * v2) + 1) #2]
+ * , groupByExpr: [k + 1])
+ * +-- childPlan
+ * we should rewrite to:
+ * Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [Alias(b) #1,
Alias(COUNT(distinct a) + 1) #2], groupByExpr: [b])
+ * +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a])
+ * +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as
a], groupByExpr: [k + 1, a])
+ * +-- childPlan
+ *
* TODO:
* 1. use different class represent different phase aggregate
* 2. if instance count is 1, shouldn't disassemble the agg plan
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
+ // used in secondDisassemble to transform local expressions into global
+ private final Map<Expression, Expression> globalOutputSubstitutionMap =
Maps.newHashMap();
+ // used in secondDisassemble to transform local expressions into global
+ private final Map<Expression, Expression> globalGroupBySubstitutionMap =
Maps.newHashMap();
+ // used to indicate the existence of a distinct function for the entire
phase
+ private boolean hasDistinctAgg = false;
@Override
public Rule build() {
return logicalAggregate().when(agg ->
!agg.isDisassembled()).thenApply(ctx -> {
LogicalAggregate<GroupPlan> aggregate = ctx.root;
- List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
- List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
+ LogicalAggregate firstAggregate = firstDisassemble(aggregate);
+ if (!hasDistinctAgg) {
+ return firstAggregate;
+ }
+ return secondDisassemble(firstAggregate);
+ }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
+ }
+
+ // only support distinct function with group by
+ // TODO: support distinct function without group by. (add second global
phase)
+ private LogicalAggregate
secondDisassemble(LogicalAggregate<LogicalAggregate> aggregate) {
+ LogicalAggregate<GroupPlan> local = aggregate.child();
+ // replace expression in globalOutputExprs and globalGroupByExprs
+ List<NamedExpression> globalOutputExprs =
local.getOutputExpressions().stream()
+ .map(e -> ExpressionUtils.replace(e,
globalOutputSubstitutionMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+ List<Expression> globalGroupByExprs =
local.getGroupByExpressions().stream()
+ .map(e -> ExpressionUtils.replace(e,
globalGroupBySubstitutionMap))
+ .collect(Collectors.toList());
+
+ // generate new plan
+ LogicalAggregate globalAggregate = new LogicalAggregate<>(
+ globalGroupByExprs,
+ globalOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ false,
+ AggPhase.GLOBAL,
+ local
+ );
+ return new LogicalAggregate<>(
+ aggregate.getGroupByExpressions(),
+ aggregate.getOutputExpressions(),
+ true,
+ aggregate.isNormalized(),
+ true,
+ AggPhase.DISTINCT_LOCAL,
+ globalAggregate
+ );
+ }
+
+ private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan>
aggregate) {
+ List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
+ Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
- // 1. generate a map from local aggregate output to global
aggregate expr substitution.
- // inputSubstitutionMap use for replacing expression in global
aggregate
- // replace rule is:
- // a: Expression is a group by key and is a slot reference.
e.g. group by k1
- // b. Expression is a group by key and is an expression.
e.g. group by k1 + 1
- // c. Expression is an aggregate function. e.g. sum(v1) in
select list
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | situation | origin expression | local output expression
| expression in global aggregate |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | a | Ref(k1)#1 | Ref(k1)#1
| Ref(k1)#1 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2
| Ref(key)#2 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3
| AF(af#3) |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction,
#x: ExprId x
- // 2. collect local aggregate output expressions and local
aggregate group by expression list
- Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
- List<Expression> localGroupByExprs =
aggregate.getGroupByExpressions();
- List<NamedExpression> localOutputExprs = Lists.newArrayList();
- for (Expression originGroupByExpr : originGroupByExprs) {
- if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ // 1. generate a map from local aggregate output to global aggregate
expr substitution.
+ // inputSubstitutionMap use for replacing expression in global
aggregate
+ // replace rule is:
+ // a: Expression is a group by key and is a slot reference.
e.g. group by k1
+ // b. Expression is a group by key and is an expression. e.g.
group by k1 + 1
+ // c. Expression is an aggregate function. e.g. sum(v1) in
select list
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | situation | origin expression | local output expression |
expression in global aggregate |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | a | Ref(k1)#1 | Ref(k1)#1 |
Ref(k1)#1 |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 |
Ref(key)#2 |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 |
AF(af#3) |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x:
ExprId x
+ // 2. collect local aggregate output expressions and local aggregate
group by expression list
+ List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+ List<NamedExpression> localOutputExprs = Lists.newArrayList();
+ for (Expression originGroupByExpr : originGroupByExprs) {
+ if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ continue;
+ }
+ if (originGroupByExpr instanceof SlotReference) {
+ inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+ globalOutputSubstitutionMap.put(originGroupByExpr,
originGroupByExpr);
+ globalGroupBySubstitutionMap.put(originGroupByExpr,
originGroupByExpr);
+ localOutputExprs.add((SlotReference) originGroupByExpr);
+ } else {
+ NamedExpression localOutputExpr = new Alias(originGroupByExpr,
originGroupByExpr.toSql());
+ inputSubstitutionMap.put(originGroupByExpr,
localOutputExpr.toSlot());
+ globalOutputSubstitutionMap.put(localOutputExpr,
localOutputExpr.toSlot());
+ globalGroupBySubstitutionMap.put(originGroupByExpr,
localOutputExpr.toSlot());
+ localOutputExprs.add(localOutputExpr);
+ }
+ }
+ List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList();
+ List<NamedExpression> distinctExprsForLocalOutput =
Lists.newArrayList();
+ for (NamedExpression originOutputExpr : originOutputExprs) {
+ Set<AggregateFunction> aggregateFunctions
+ =
originOutputExpr.collect(AggregateFunction.class::isInstance);
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ if (inputSubstitutionMap.containsKey(aggregateFunction)) {
continue;
}
- if (originGroupByExpr instanceof SlotReference) {
- inputSubstitutionMap.put(originGroupByExpr,
originGroupByExpr);
- localOutputExprs.add((SlotReference) originGroupByExpr);
- } else {
- NamedExpression localOutputExpr = new
Alias(originGroupByExpr, originGroupByExpr.toSql());
- inputSubstitutionMap.put(originGroupByExpr,
localOutputExpr.toSlot());
- localOutputExprs.add(localOutputExpr);
- }
- }
- for (NamedExpression originOutputExpr : originOutputExprs) {
- Set<AggregateFunction> aggregateFunctions
- =
originOutputExpr.collect(AggregateFunction.class::isInstance);
- for (AggregateFunction aggregateFunction : aggregateFunctions)
{
- if (inputSubstitutionMap.containsKey(aggregateFunction)) {
- continue;
+ if (aggregateFunction.isDistinct()) {
+ hasDistinctAgg = true;
+ for (Expression expr : aggregateFunction.children()) {
+ if (expr instanceof SlotReference) {
+ distinctExprsForLocalOutput.add((SlotReference)
expr);
+ if (!inputSubstitutionMap.containsKey(expr)) {
+ inputSubstitutionMap.put(expr, expr);
+ globalOutputSubstitutionMap.put(expr, expr);
+ globalGroupBySubstitutionMap.put(expr, expr);
+ }
+ } else {
+ NamedExpression globalOutputExpr = new Alias(expr,
expr.toSql());
+ distinctExprsForLocalOutput.add(globalOutputExpr);
+ if (!inputSubstitutionMap.containsKey(expr)) {
+ inputSubstitutionMap.put(expr,
globalOutputExpr.toSlot());
+
globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot());
+ globalGroupBySubstitutionMap.put(expr,
globalOutputExpr.toSlot());
+ }
+ }
+ distinctExprsForLocalGroupBy.add(expr);
}
- NamedExpression localOutputExpr = new
Alias(aggregateFunction, aggregateFunction.toSql());
- Expression substitutionValue =
aggregateFunction.withChildren(
- Lists.newArrayList(localOutputExpr.toSlot()));
- inputSubstitutionMap.put(aggregateFunction,
substitutionValue);
- localOutputExprs.add(localOutputExpr);
+ continue;
}
+ NamedExpression localOutputExpr = new Alias(aggregateFunction,
aggregateFunction.toSql());
+ Expression substitutionValue = aggregateFunction.withChildren(
+ Lists.newArrayList(localOutputExpr.toSlot()));
+ inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+ globalOutputSubstitutionMap.put(aggregateFunction,
substitutionValue);
+ localOutputExprs.add(localOutputExpr);
}
+ }
- // 3. replace expression in globalOutputExprs and
globalGroupByExprs
- List<NamedExpression> globalOutputExprs =
aggregate.getOutputExpressions().stream()
- .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
- .map(NamedExpression.class::cast)
- .collect(Collectors.toList());
- List<Expression> globalGroupByExprs = localGroupByExprs.stream()
- .map(e -> ExpressionUtils.replace(e,
inputSubstitutionMap)).collect(Collectors.toList());
-
- // 4. generate new plan
- LogicalAggregate localAggregate = new LogicalAggregate<>(
- localGroupByExprs,
- localOutputExprs,
- true,
- aggregate.isNormalized(),
- AggPhase.LOCAL,
- aggregate.child()
- );
- return new LogicalAggregate<>(
- globalGroupByExprs,
- globalOutputExprs,
- true,
- aggregate.isNormalized(),
- AggPhase.GLOBAL,
- localAggregate
- );
- }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
+ // 3. replace expression in globalOutputExprs and globalGroupByExprs
+ List<NamedExpression> globalOutputExprs =
aggregate.getOutputExpressions().stream()
+ .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+ List<Expression> globalGroupByExprs = localGroupByExprs.stream()
+ .map(e -> ExpressionUtils.replace(e,
inputSubstitutionMap)).collect(Collectors.toList());
+ // To avoid repeated substitution of distinct expressions,
+ // here the expressions are put into the local after the substitution
is completed
+ localOutputExprs.addAll(distinctExprsForLocalOutput);
+ localGroupByExprs.addAll(distinctExprsForLocalGroupBy);
+ // 4. generate new plan
+ LogicalAggregate localAggregate = new LogicalAggregate<>(
+ localGroupByExprs,
+ localOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ false,
+ AggPhase.LOCAL,
+ aggregate.child()
+ );
+ return new LogicalAggregate<>(
+ globalGroupByExprs,
+ globalOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ true,
+ AggPhase.GLOBAL,
+ localAggregate
+ );
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
index 0fe139b85b..45a4a3c027 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -124,7 +124,7 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory {
root = new LogicalProject<>(bottomProjections, root);
}
root = new LogicalAggregate<>(newKeys, newOutputs,
aggregate.isDisassembled(),
- true, aggregate.getAggPhase(), root);
+ true, aggregate.isFinalPhase(), aggregate.getAggPhase(),
root);
List<NamedExpression> projections = outputs.stream()
.map(e -> ExpressionUtils.replace(e, substitutionMap))
.map(NamedExpression.class::cast)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
index 73de61a058..69572b070a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
@@ -21,19 +21,50 @@ import
org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
+import java.util.Objects;
+
/**
* The function which consume arguments in lots of rows and product one value.
*/
public abstract class AggregateFunction extends BoundFunction {
private DataType intermediate;
+ private final boolean isDistinct;
public AggregateFunction(String name, Expression... arguments) {
super(name, arguments);
+ isDistinct = false;
+ }
+
+ public AggregateFunction(String name, boolean isDistinct, Expression...
arguments) {
+ super(name, arguments);
+ this.isDistinct = isDistinct;
}
public abstract DataType getIntermediateType();
+ public boolean isDistinct() {
+ return isDistinct;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ AggregateFunction that = (AggregateFunction) o;
+ return Objects.equals(isDistinct, that.isDistinct) &&
Objects.equals(intermediate, that.intermediate)
+ && Objects.equals(getName(), that.getName()) &&
Objects.equals(children, that.children);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(isDistinct, intermediate, getName(), children);
+ }
+
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAggregateFunction(this, context);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java
index a31122ab7a..e594671733 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java
@@ -37,8 +37,8 @@ public class Count extends AggregateFunction {
this.isStar = true;
}
- public Count(Expression child) {
- super("count", child);
+ public Count(Expression child, boolean isDistinct) {
+ super("count", isDistinct, child);
this.isStar = false;
}
@@ -62,7 +62,7 @@ public class Count extends AggregateFunction {
if (children.size() == 0) {
return new Count();
}
- return new Count(children.get(0));
+ return new Count(children.get(0), isDistinct());
}
@Override
@@ -79,6 +79,9 @@ public class Count extends AggregateFunction {
.stream()
.map(Expression::toSql)
.collect(Collectors.joining(", "));
+ if (isDistinct()) {
+ return "count(distinct " + args + ")";
+ }
return "count(" + args + ")";
}
@@ -91,6 +94,9 @@ public class Count extends AggregateFunction {
.stream()
.map(Expression::toString)
.collect(Collectors.joining(", "));
+ if (isDistinct()) {
+ return "count(distinct " + args + ")";
+ }
return "count(" + args + ")";
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index cbe9e402ef..0cca04950d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -59,6 +59,13 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
private final List<NamedExpression> outputExpressions;
private final AggPhase aggPhase;
+ // use for scenes containing distinct agg
+ // 1. If there are LOCAL and GLOBAL phases, global is the final phase
+ // 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL
is the final phase
+ // 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
+ // DISTINCT_GLOBAL is the final phase
+ private final boolean isFinalPhase;
+
/**
* Desc: Constructor for LogicalAggregate.
*/
@@ -66,7 +73,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
CHILD_TYPE child) {
- this(groupByExpressions, outputExpressions, false, false,
AggPhase.GLOBAL, child);
+ this(groupByExpressions, outputExpressions, false, false, true,
AggPhase.GLOBAL, child);
}
public LogicalAggregate(
@@ -74,9 +81,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
List<NamedExpression> outputExpressions,
boolean disassembled,
boolean normalized,
+ boolean isFinalPhase,
AggPhase aggPhase,
CHILD_TYPE child) {
- this(groupByExpressions, outputExpressions, disassembled, normalized,
+ this(groupByExpressions, outputExpressions, disassembled, normalized,
isFinalPhase,
aggPhase, Optional.empty(), Optional.empty(), child);
}
@@ -88,6 +96,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
List<NamedExpression> outputExpressions,
boolean disassembled,
boolean normalized,
+ boolean isFinalPhase,
AggPhase aggPhase,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
@@ -97,6 +106,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
this.outputExpressions = outputExpressions;
this.disassembled = disassembled;
this.normalized = normalized;
+ this.isFinalPhase = isFinalPhase;
this.aggPhase = aggPhase;
}
@@ -149,6 +159,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
return normalized;
}
+ public boolean isFinalPhase() {
+ return isFinalPhase;
+ }
+
/**
* Determine the equality with another plan
*/
@@ -164,37 +178,37 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
&& Objects.equals(outputExpressions, that.outputExpressions)
&& aggPhase == that.aggPhase
&& disassembled == that.disassembled
- && normalized == that.normalized;
+ && normalized == that.normalized
+ && isFinalPhase == that.isFinalPhase;
}
@Override
public int hashCode() {
- return Objects.hash(groupByExpressions, outputExpressions, aggPhase,
normalized, disassembled);
+ return Objects.hash(groupByExpressions, outputExpressions, aggPhase,
normalized, disassembled, isFinalPhase);
}
@Override
public LogicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
- disassembled, normalized, aggPhase, children.get(0));
+ disassembled, normalized, isFinalPhase, aggPhase,
children.get(0));
}
@Override
public LogicalAggregate<Plan>
withGroupExpression(Optional<GroupExpression> groupExpression) {
- return new LogicalAggregate<>(groupByExpressions, outputExpressions,
- disassembled, normalized, aggPhase, groupExpression,
Optional.of(getLogicalProperties()),
- children.get(0));
+ return new LogicalAggregate<>(groupByExpressions, outputExpressions,
disassembled, normalized, isFinalPhase,
+ aggPhase, groupExpression,
Optional.of(getLogicalProperties()), children.get(0));
}
@Override
public LogicalAggregate<Plan>
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
- return new LogicalAggregate<>(groupByExpressions, outputExpressions,
- disassembled, normalized, aggPhase, Optional.empty(),
logicalProperties, children.get(0));
+ return new LogicalAggregate<>(groupByExpressions, outputExpressions,
disassembled, normalized, isFinalPhase,
+ aggPhase, Optional.empty(), logicalProperties,
children.get(0));
}
public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression>
groupByExprList,
List<NamedExpression> outputExpressionList) {
return new LogicalAggregate<>(groupByExprList, outputExpressionList,
- disassembled, normalized, aggPhase, child());
+ disassembled, normalized, isFinalPhase, aggPhase, child());
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
index f2384920e5..8557a61ea5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
@@ -53,11 +53,18 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
private final boolean usingStream;
+ // Used for plans containing a distinct aggregate; identifies the final aggregation phase:
+ // 1. If there are LOCAL and GLOBAL phases, GLOBAL is the final phase
+ // 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL
is the final phase
+ // 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
+ // DISTINCT_GLOBAL is the final phase
+ private final boolean isFinalPhase;
+
public PhysicalAggregate(List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
List<Expression> partitionExpressions, AggPhase aggPhase, boolean
usingStream,
- LogicalProperties logicalProperties, CHILD_TYPE child) {
+ boolean isFinalPhase, LogicalProperties logicalProperties,
CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, partitionExpressions,
aggPhase, usingStream,
- Optional.empty(), logicalProperties, child);
+ isFinalPhase, Optional.empty(), logicalProperties, child);
}
/**
@@ -69,7 +76,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
* @param usingStream whether it's stream agg.
*/
public PhysicalAggregate(List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
- List<Expression> partitionExpressions, AggPhase aggPhase, boolean
usingStream,
+ List<Expression> partitionExpressions, AggPhase aggPhase, boolean
usingStream, boolean isFinalPhase,
Optional<GroupExpression> groupExpression, LogicalProperties
logicalProperties,
CHILD_TYPE child) {
super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties,
child);
@@ -78,6 +85,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
this.aggPhase = aggPhase;
this.partitionExpressions = partitionExpressions;
this.usingStream = usingStream;
+ this.isFinalPhase = isFinalPhase;
}
/**
@@ -89,7 +97,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
* @param usingStream whether it's stream agg.
*/
public PhysicalAggregate(List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
- List<Expression> partitionExpressions, AggPhase aggPhase, boolean
usingStream,
+ List<Expression> partitionExpressions, AggPhase aggPhase, boolean
usingStream, boolean isFinalPhase,
Optional<GroupExpression> groupExpression, LogicalProperties
logicalProperties,
PhysicalProperties physicalProperties, CHILD_TYPE child) {
super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties,
physicalProperties, child);
@@ -98,6 +106,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
this.aggPhase = aggPhase;
this.partitionExpressions = partitionExpressions;
this.usingStream = usingStream;
+ this.isFinalPhase = isFinalPhase;
}
public AggPhase getAggPhase() {
@@ -112,6 +121,10 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
return outputExpressions;
}
+ public boolean isFinalPhase() {
+ return isFinalPhase;
+ }
+
public boolean isUsingStream() {
return usingStream;
}
@@ -156,36 +169,38 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnary<CH
&& Objects.equals(outputExpressions, that.outputExpressions)
&& Objects.equals(partitionExpressions,
that.partitionExpressions)
&& usingStream == that.usingStream
- && aggPhase == that.aggPhase;
+ && aggPhase == that.aggPhase
+ && isFinalPhase == that.isFinalPhase;
}
@Override
public int hashCode() {
- return Objects.hash(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase, usingStream);
+ return Objects.hash(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase, usingStream,
+ isFinalPhase);
}
@Override
public PhysicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase,
- usingStream, getLogicalProperties(), children.get(0));
+ usingStream, isFinalPhase, getLogicalProperties(),
children.get(0));
}
@Override
public PhysicalAggregate<CHILD_TYPE>
withGroupExpression(Optional<GroupExpression> groupExpression) {
return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase,
- usingStream, groupExpression, getLogicalProperties(), child());
+ usingStream, isFinalPhase, groupExpression,
getLogicalProperties(), child());
}
@Override
public PhysicalAggregate<CHILD_TYPE>
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase,
- usingStream, Optional.empty(), logicalProperties.get(),
child());
+ usingStream, isFinalPhase, Optional.empty(),
logicalProperties.get(), child());
}
@Override
public PhysicalAggregate<CHILD_TYPE>
withPhysicalProperties(PhysicalProperties physicalProperties) {
return new PhysicalAggregate<>(groupByExpressions, outputExpressions,
partitionExpressions, aggPhase,
- usingStream, Optional.empty(), getLogicalProperties(),
physicalProperties, child());
+ usingStream, isFinalPhase, Optional.empty(),
getLogicalProperties(), physicalProperties, child());
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java
index dd09c58a50..29ef8bb2ff 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java
@@ -360,9 +360,9 @@ public class HavingClauseTest extends AnalyzeCheckTestBase
implements PatternMat
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk,
Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte)
2)), "(pk + 2)");
Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
- Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1),
Literal.of((byte) 1)), "(COUNT(a1) + 1)");
+ Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1,
false), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)),
"SUM((a1 + a2))");
- Alias v1 = new Alias(new ExprId(0), new Count(a2), "v1");
+ Alias v1 = new Alias(new ExprId(0), new Count(a2, false), "v1");
PlanChecker.from(connectContext).analyze(sql)
.matchesFromRoot(
logicalProject(
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
index 08d91b777f..fe0b577cc4 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
@@ -263,6 +263,7 @@ public class ChildOutputPropertyDeriverTest {
Lists.newArrayList(key),
AggPhase.LOCAL,
true,
+ true,
logicalProperties,
groupPlan
);
@@ -286,6 +287,7 @@ public class ChildOutputPropertyDeriverTest {
Lists.newArrayList(partition),
AggPhase.GLOBAL,
true,
+ true,
logicalProperties,
groupPlan
);
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
index dda5c0b006..9802a7d66b 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
@@ -146,6 +146,7 @@ public class RequestPropertyDeriverTest {
Lists.newArrayList(key),
AggPhase.LOCAL,
true,
+ true,
logicalProperties,
groupPlan
);
@@ -168,6 +169,7 @@ public class RequestPropertyDeriverTest {
Lists.newArrayList(partition),
AggPhase.GLOBAL,
true,
+ true,
logicalProperties,
groupPlan
);
@@ -192,6 +194,7 @@ public class RequestPropertyDeriverTest {
Lists.newArrayList(),
AggPhase.GLOBAL,
true,
+ true,
logicalProperties,
groupPlan
);
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index 72f4a8829a..ef32f31def 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.AggPhase;
@@ -269,6 +270,86 @@ public class AggregateDisassembleTest {
global.getOutputExpressions().get(0).getExprId());
}
+ /**
+ * the initial plan is:
+ * Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age + 1) + 2)
as c], groupByExpr: [id + 3])
+ * +-- childPlan(id, name, age)
+ * we should rewrite to:
+ * Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct b) +
2) as c], groupByExpr: [a])
+ * +-- Aggregate(phase: [GLOBAL], outputExpr: [a, b], groupByExpr: [a,
b])
+ * +-- Aggregate(phase: [LOCAL], outputExpr: [(id + 3) as a, (age +
1) as b], groupByExpr: [id + 3, age + 1])
+ * +-- childPlan(id, name, age)
+ */
+ @Test
+ public void distinctAggregateWithGroupBy() {
+ List<Expression> groupExpressionList = Lists.newArrayList(
+ new Add(rStudent.getOutput().get(0).toSlot(), new
IntegerLiteral(3)));
+ List<NamedExpression> outputExpressionList = Lists.newArrayList(new
Alias(
+ new Add(new Count(new
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)), true),
+ new IntegerLiteral(2)), "c"));
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
+
+ Plan after = rewrite(root);
+
+ Assertions.assertTrue(after instanceof LogicalUnary);
+ Assertions.assertTrue(after instanceof LogicalAggregate);
+ Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
+ LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after;
+ LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0);
+ LogicalAggregate<Plan> local = (LogicalAggregate)
after.child(0).child(0);
+ Assertions.assertEquals(AggPhase.DISTINCT_LOCAL,
distinctLocal.getAggPhase());
+ Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
+ Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+ // check local:
+ // id + 3
+ Expression localOutput0 = new
Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
+ // age + 1
+ Expression localOutput1 = new
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
+ // id + 3
+ Expression localGroupBy0 = new
Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
+ // age + 1
+ Expression localGroupBy1 = new
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
+
+ Assertions.assertEquals(2, local.getOutputExpressions().size());
+ Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof
Alias);
+ Assertions.assertEquals(localOutput0,
local.getOutputExpressions().get(0).child(0));
+ Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof
Alias);
+ Assertions.assertEquals(localOutput1,
local.getOutputExpressions().get(1).child(0));
+ Assertions.assertEquals(2, local.getGroupByExpressions().size());
+ Assertions.assertEquals(localGroupBy0,
local.getGroupByExpressions().get(0));
+ Assertions.assertEquals(localGroupBy1,
local.getGroupByExpressions().get(1));
+
+ // check global:
+ Expression globalOutput0 =
local.getOutputExpressions().get(0).toSlot();
+ Expression globalOutput1 =
local.getOutputExpressions().get(1).toSlot();
+ Expression globalGroupBy0 =
local.getOutputExpressions().get(0).toSlot();
+ Expression globalGroupBy1 =
local.getOutputExpressions().get(1).toSlot();
+
+ Assertions.assertEquals(2, global.getOutputExpressions().size());
+ Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof
SlotReference);
+ Assertions.assertEquals(globalOutput0,
global.getOutputExpressions().get(0));
+ Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof
SlotReference);
+ Assertions.assertEquals(globalOutput1,
global.getOutputExpressions().get(1));
+ Assertions.assertEquals(2, global.getGroupByExpressions().size());
+ Assertions.assertEquals(globalGroupBy0,
global.getGroupByExpressions().get(0));
+ Assertions.assertEquals(globalGroupBy1,
global.getGroupByExpressions().get(1));
+
+ // check distinct local:
+ Expression distinctLocalOutput = new Add(new
Count(local.getOutputExpressions().get(1).toSlot(), true),
+ new IntegerLiteral(2));
+ Expression distinctLocalGroupBy =
local.getOutputExpressions().get(0).toSlot();
+
+ Assertions.assertEquals(1,
distinctLocal.getOutputExpressions().size());
+ Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0)
instanceof Alias);
+ Assertions.assertEquals(distinctLocalOutput,
distinctLocal.getOutputExpressions().get(0).child(0));
+ Assertions.assertEquals(1,
distinctLocal.getGroupByExpressions().size());
+ Assertions.assertEquals(distinctLocalGroupBy,
distinctLocal.getGroupByExpressions().get(0));
+
+ // check id:
+ Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
+ distinctLocal.getOutputExpressions().get(0).getExprId());
+ }
+
private Plan rewrite(Plan input) {
return PlanRewriter.topDownRewrite(input, new ConnectContext(), new
AggregateDisassemble());
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
index 5860d95c6b..71d8248655 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundStar;
+import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.types.IntegerType;
@@ -168,6 +169,25 @@ public class ExpressionEqualsTest {
Assertions.assertEquals(sum1.hashCode(), sum2.hashCode());
}
+ @Test
+ public void testAggregateFunction() {
+ Count count1 = new Count();
+ Count count2 = new Count();
+ Assertions.assertEquals(count1, count2);
+ Assertions.assertEquals(count1.hashCode(), count2.hashCode());
+
+ Count count3 = new Count(child1, true);
+ Count count4 = new Count(child2, true);
+ Assertions.assertEquals(count3, count4);
+ Assertions.assertEquals(count3.hashCode(), count4.hashCode());
+
+ // negative case: counts that differ in the distinct flag must not be equal
+ Count count5 = new Count(child1, true);
+ Count count6 = new Count(child2, false);
+ Assertions.assertNotEquals(count5, count6);
+ Assertions.assertNotEquals(count5.hashCode(), count6.hashCode());
+ }
+
@Test
public void testNamedExpression() {
ExprId aliasId = new ExprId(2);
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
index 1d7878a2db..cdd5454e78 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
@@ -71,17 +71,17 @@ public class PlanEqualsTest {
unexpected = new LogicalAggregate<>(Lists.newArrayList(),
ImmutableList.of(
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE,
true, Lists.newArrayList())),
- true, false, AggPhase.GLOBAL, child);
+ true, false, true, AggPhase.GLOBAL, child);
Assertions.assertNotEquals(unexpected, actual);
unexpected = new LogicalAggregate<>(Lists.newArrayList(),
ImmutableList.of(
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE,
true, Lists.newArrayList())),
- false, true, AggPhase.GLOBAL, child);
+ false, true, true, AggPhase.GLOBAL, child);
Assertions.assertNotEquals(unexpected, actual);
unexpected = new LogicalAggregate<>(Lists.newArrayList(),
ImmutableList.of(
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE,
true, Lists.newArrayList())),
- false, false, AggPhase.LOCAL, child);
+ false, false, true, AggPhase.LOCAL, child);
Assertions.assertNotEquals(unexpected, actual);
}
@@ -178,20 +178,20 @@ public class PlanEqualsTest {
List<NamedExpression> outputExpressionList = ImmutableList.of(
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE,
true, Lists.newArrayList()));
PhysicalAggregate<Plan> actual = new
PhysicalAggregate<>(Lists.newArrayList(), outputExpressionList,
- Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties,
child);
+ Lists.newArrayList(), AggPhase.LOCAL, true, true,
logicalProperties, child);
List<NamedExpression> outputExpressionList1 = ImmutableList.of(
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE,
true, Lists.newArrayList()));
PhysicalAggregate<Plan> expected = new
PhysicalAggregate<>(Lists.newArrayList(),
outputExpressionList1,
- Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties,
child);
+ Lists.newArrayList(), AggPhase.LOCAL, true, true,
logicalProperties, child);
Assertions.assertEquals(expected, actual);
List<NamedExpression> outputExpressionList2 = ImmutableList.of(
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE,
true, Lists.newArrayList()));
PhysicalAggregate<Plan> unexpected = new
PhysicalAggregate<>(Lists.newArrayList(),
outputExpressionList2,
- Lists.newArrayList(), AggPhase.LOCAL, false,
logicalProperties, child);
+ Lists.newArrayList(), AggPhase.LOCAL, false, true,
logicalProperties, child);
Assertions.assertNotEquals(unexpected, actual);
}
diff --git a/regression-test/data/nereids_syntax_p0/function.out
b/regression-test/data/nereids_syntax_p0/function.out
index cac9a7c5b1..b1d705b814 100644
--- a/regression-test/data/nereids_syntax_p0/function.out
+++ b/regression-test/data/nereids_syntax_p0/function.out
@@ -11,6 +11,11 @@
-- !count --
3 3
+-- !distinct_count --
+1
+1
+1
+
-- !avg --
2.5E-323 1.1644193E-317
diff --git a/regression-test/suites/nereids_syntax_p0/function.groovy
b/regression-test/suites/nereids_syntax_p0/function.groovy
index c4099a0798..a041fc36ab 100644
--- a/regression-test/suites/nereids_syntax_p0/function.groovy
+++ b/regression-test/suites/nereids_syntax_p0/function.groovy
@@ -41,6 +41,10 @@ suite("function") {
SELECT count(c_city), count(*) AS custdist FROM customer;
"""
+ order_qt_distinct_count """
+ SELECT count(distinct c_custkey + 1) AS custdist FROM customer group
by c_city;
+ """
+
order_qt_avg """
SELECT avg(lo_tax), avg(lo_extendedprice) AS avg_extendedprice FROM
lineorder;
"""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]