This is an automated email from the ASF dual-hosted git repository.
huajianlan pushed a commit to branch 2.1.3-ygl
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/2.1.3-ygl by this push:
new d1d6758b107 4 phase
d1d6758b107 is described below
commit d1d6758b107e7764869f08a154739c582148d3be
Author: 924060929 <[email protected]>
AuthorDate: Tue Apr 23 18:58:16 2024 +0800
4 phase
---
.../properties/ChildrenPropertiesRegulator.java | 10 -
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../rules/implementation/AggregateStrategies.java | 204 ++++++++++++++++++++-
3 files changed, 198 insertions(+), 17 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
index 366730f7dc5..5df1bfc0ce1 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
@@ -115,16 +115,6 @@ public class ChildrenPropertiesRegulator extends
PlanVisitor<Boolean, Void> {
// this means one stage gather agg, usually bad pattern
return false;
}
- // forbid three or four stage distinct agg inter by distribute
- if (agg.getAggMode() == AggMode.BUFFER_TO_BUFFER &&
children.get(0).getPlan() instanceof PhysicalDistribute) {
- // if distinct without group by key, we prefer three or four stage
distinct agg
- // because the second phase of multi-distinct only have one
instance, and it is slow generally.
- if (agg.getGroupByExpressions().size() == 1
- && agg.getOutputExpressions().size() == 1) {
- return true;
- }
- return false;
- }
// forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
// TODO: this is forbid good plan after cte reuse by mistake
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 696463523f6..e0f8ef23013 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -416,6 +416,7 @@ public enum RuleType {
TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION),
THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
+
FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE(RuleTypeClass.IMPLEMENTATION),
LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION),
LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION),
LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index edbd28677b4..556ca20e47a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -292,16 +292,23 @@ public class AggregateStrategies implements
ImplementationRuleFactory {
// .when(agg -> agg.getDistinctArguments().size() == 1)
// .thenApplyMulti(ctx ->
twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
// ),
+
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
+ basePattern
+ .when(agg -> agg.getDistinctArguments().size() == 1)
+ .thenApplyMulti(ctx ->
+
fourPhaseAggregateWithDistinctAndFullDistribute(ctx.root, ctx.connectContext)
+ )
+ ),
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
- basePattern
- .when(agg -> agg.getDistinctArguments().size() ==
1)
- .thenApplyMulti(ctx ->
threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
+ basePattern
+ .when(agg -> agg.getDistinctArguments().size() == 1)
+ .thenApplyMulti(ctx ->
threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
- basePattern
- .when(agg -> agg.getDistinctArguments().size() ==
1)
- .when(agg -> agg.getGroupByExpressions().isEmpty())
- .thenApplyMulti(ctx ->
fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
+ basePattern
+ .when(agg -> agg.getDistinctArguments().size() == 1)
+ // .when(agg -> agg.getGroupByExpressions().isEmpty())
+ .thenApplyMulti(ctx ->
fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
)
);
}
@@ -1831,6 +1838,189 @@ public class AggregateStrategies implements
ImplementationRuleFactory {
.build();
}
+ /**
+ * sql:
+ * select count(distinct name), sum(age) from student;
+ * <p>
+ * 4 phase plan
+ * DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name),
sum(age#5)], [GATHER]
+ * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy(), output(count(name),
partial_sum(age)), hash distribute by name
+ * +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name,
partial_sum(age)), hash_distribute by name
+ * +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name,
partial_sum(age))
+ * +--scan(name, age)
+ */
+ private List<PhysicalHashAggregate<? extends Plan>>
fourPhaseAggregateWithDistinctAndFullDistribute(
+ LogicalAggregate<? extends Plan> logicalAgg, ConnectContext
connectContext) {
+ boolean couldBanned = couldConvertToMulti(logicalAgg);
+
+ Set<AggregateFunction> aggregateFunctions =
logicalAgg.getAggregateFunctions();
+
+ Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
+ .filter(AggregateFunction::isDistinct)
+ .flatMap(aggregateExpression ->
aggregateExpression.getArguments().stream())
+ .filter(NamedExpression.class::isInstance)
+ .map(NamedExpression.class::cast)
+ .collect(ImmutableSet.toImmutableSet());
+
+ Set<NamedExpression> localAggGroupBySet =
ImmutableSet.<NamedExpression>builder()
+ .addAll((List<NamedExpression>) (List)
logicalAgg.getGroupByExpressions())
+ .addAll(distinctArguments)
+ .build();
+
+ AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL,
AggMode.INPUT_TO_BUFFER, couldBanned);
+
+ Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 =
aggregateFunctions.stream()
+ .filter(aggregateFunction -> !aggregateFunction.isDistinct())
+ .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
+ AggregateExpression localAggExpr = new
AggregateExpression(expr, inputToBufferParam);
+ return new Alias(localAggExpr);
+ }, (oldValue, newValue) -> newValue));
+
+ List<NamedExpression> localAggOutput =
ImmutableList.<NamedExpression>builder()
+ .addAll(localAggGroupBySet)
+ .addAll(nonDistinctAggFunctionToAliasPhase1.values())
+ .build();
+
+ List<Expression> localAggGroupBy =
ImmutableList.copyOf(localAggGroupBySet);
+ boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext,
localAggGroupBy);
+ List<Expression> partitionExpressions =
getHashAggregatePartitionExpressions(logicalAgg);
+ RequireProperties requireAny =
RequireProperties.of(PhysicalProperties.ANY);
+
+ boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() &&
localAggOutput.isEmpty();
+
+ // The BE does not recommend generating an aggregate node with both an
empty group by and an empty output,
+ // so add a null literal to both the group by list and the output
+ if (isGroupByEmptySelectEmpty) {
+ localAggGroupBy = ImmutableList.of(new
NullLiteral(TinyIntType.INSTANCE));
+ localAggOutput = ImmutableList.of(new Alias(new
NullLiteral(TinyIntType.INSTANCE)));
+ }
+
+ PhysicalHashAggregate<Plan> anyLocalAgg = new
PhysicalHashAggregate<>(localAggGroupBy,
+ localAggOutput, Optional.of(partitionExpressions),
inputToBufferParam,
+ maybeUsingStreamAgg, Optional.empty(),
logicalAgg.getLogicalProperties(),
+ requireAny, logicalAgg.child());
+
+ AggregateParam bufferToBufferParam = new
AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
+ Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
+ nonDistinctAggFunctionToAliasPhase1.entrySet()
+ .stream()
+ .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv
-> {
+ AggregateFunction originFunction = kv.getKey();
+ Alias localOutput = kv.getValue();
+ AggregateExpression globalAggExpr = new
AggregateExpression(
+ originFunction, bufferToBufferParam,
localOutput.toSlot());
+ return new Alias(globalAggExpr);
+ }));
+
+ List<NamedExpression> globalAggOutput =
ImmutableList.<NamedExpression>builder()
+ .addAll(localAggGroupBySet)
+ .addAll(nonDistinctAggFunctionToAliasPhase2.values())
+ .build();
+
+ // The BE does not recommend generating an aggregate node with both an
empty group by and an empty output,
+ // so output a null literal here too (the group by list already holds one)
+ if (isGroupByEmptySelectEmpty) {
+ globalAggOutput = ImmutableList.of(new Alias(new
NullLiteral(TinyIntType.INSTANCE)));
+ }
+
+ RequireProperties requireGroupByAndDistinctHash = RequireProperties.of(
+ PhysicalProperties.createHash(localAggGroupBy,
ShuffleType.REQUIRE));
+
+ //phase 2
+ PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new
PhysicalHashAggregate<>(
+ localAggGroupBy, globalAggOutput,
Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
+ bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
+ requireGroupByAndDistinctHash, anyLocalAgg);
+
+ // phase 3
+ AggregateParam distinctLocalParam = new AggregateParam(
+ AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+ Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 =
new HashMap<>();
+ List<NamedExpression> localDistinctOutput = Lists.newArrayList();
+ for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
+ NamedExpression outputExpr =
logicalAgg.getOutputExpressions().get(i);
+ List<AggregateFunction> needUpdateSlot = Lists.newArrayList();
+ NamedExpression outputExprPhase3 = (NamedExpression) outputExpr
+ .rewriteDownShortCircuit(expr -> {
+ if (expr instanceof AggregateFunction) {
+ AggregateFunction aggregateFunction =
(AggregateFunction) expr;
+ if (aggregateFunction.isDistinct()) {
+ Set<Expression> aggChild =
Sets.newLinkedHashSet(aggregateFunction.children());
+ Preconditions.checkArgument(aggChild.size() ==
1
+ ||
aggregateFunction.getDistinctArguments().size() == 1,
+ "cannot process more than one child in
aggregate distinct function: "
+ + aggregateFunction);
+ AggregateFunction nonDistinct =
aggregateFunction
+ .withDistinctAndChildren(false,
ImmutableList.copyOf(aggChild));
+ AggregateExpression nonDistinctAggExpr = new
AggregateExpression(nonDistinct,
+ distinctLocalParam,
aggregateFunction.child(0));
+ return nonDistinctAggExpr;
+ } else {
+ needUpdateSlot.add(aggregateFunction);
+ Alias alias =
nonDistinctAggFunctionToAliasPhase2.get(expr);
+ return new
AggregateExpression(aggregateFunction,
+ new
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_BUFFER),
+ alias.toSlot());
+ }
+ }
+ return expr;
+ });
+ for (AggregateFunction originFunction : needUpdateSlot) {
+ nonDistinctAggFunctionToAliasPhase3.put(originFunction,
(Alias) outputExprPhase3);
+ }
+ localDistinctOutput.add(outputExprPhase3);
+
+ }
+ PhysicalHashAggregate<? extends Plan> distinctLocal = new
PhysicalHashAggregate<>(
+ logicalAgg.getGroupByExpressions(), localDistinctOutput,
Optional.empty(),
+ distinctLocalParam, false, logicalAgg.getLogicalProperties(),
+ requireGroupByAndDistinctHash, anyLocalHashGlobalAgg);
+
+ //phase 4
+ AggregateParam distinctGlobalParam = new AggregateParam(
+ AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT,
couldBanned);
+ List<NamedExpression> globalDistinctOutput = Lists.newArrayList();
+ for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
+ NamedExpression outputExpr =
logicalAgg.getOutputExpressions().get(i);
+ NamedExpression outputExprPhase4 = (NamedExpression)
outputExpr.rewriteDownShortCircuit(expr -> {
+ if (expr instanceof AggregateFunction) {
+ AggregateFunction aggregateFunction = (AggregateFunction)
expr;
+ if (aggregateFunction.isDistinct()) {
+ Set<Expression> aggChild =
Sets.newLinkedHashSet(aggregateFunction.children());
+ Preconditions.checkArgument(aggChild.size() == 1
+ ||
aggregateFunction.getDistinctArguments().size() == 1,
+ "cannot process more than one child in
aggregate distinct function: "
+ + aggregateFunction);
+ AggregateFunction nonDistinct = aggregateFunction
+ .withDistinctAndChildren(false,
ImmutableList.copyOf(aggChild));
+ int idx =
logicalAgg.getOutputExpressions().indexOf(outputExpr);
+ Alias localDistinctAlias = (Alias)
(localDistinctOutput.get(idx));
+ return new AggregateExpression(nonDistinct,
+ distinctGlobalParam,
localDistinctAlias.toSlot());
+ } else {
+ Alias alias =
nonDistinctAggFunctionToAliasPhase3.get(expr);
+ return new AggregateExpression(aggregateFunction,
+ new AggregateParam(AggPhase.DISTINCT_LOCAL,
AggMode.BUFFER_TO_RESULT),
+ alias.toSlot());
+ }
+ }
+ return expr;
+ });
+ globalDistinctOutput.add(outputExprPhase4);
+ }
+
+ RequireProperties requireGroupByHash = RequireProperties.of(
+
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(),
ShuffleType.REQUIRE));
+ PhysicalHashAggregate<? extends Plan> distinctGlobal = new
PhysicalHashAggregate<>(
+ logicalAgg.getGroupByExpressions(), globalDistinctOutput,
Optional.empty(),
+ distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
+ requireGroupByHash, distinctLocal);
+
+ return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
+ .add(distinctGlobal)
+ .build();
+ }
+
private boolean couldConvertToMulti(LogicalAggregate<? extends Plan>
aggregate) {
Set<AggregateFunction> aggregateFunctions =
aggregate.getAggregateFunctions();
for (AggregateFunction func : aggregateFunctions) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]