This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch 2.1.3-ygl
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/2.1.3-ygl by this push:
     new d1d6758b107 4 phase
d1d6758b107 is described below

commit d1d6758b107e7764869f08a154739c582148d3be
Author: 924060929 <[email protected]>
AuthorDate: Tue Apr 23 18:58:16 2024 +0800

    4 phase
---
 .../properties/ChildrenPropertiesRegulator.java    |  10 -
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../rules/implementation/AggregateStrategies.java  | 204 ++++++++++++++++++++-
 3 files changed, 198 insertions(+), 17 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
index 366730f7dc5..5df1bfc0ce1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
@@ -115,16 +115,6 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
             // this means one stage gather agg, usually bad pattern
             return false;
         }
-        // forbid three or four stage distinct agg inter by distribute
-        if (agg.getAggMode() == AggMode.BUFFER_TO_BUFFER && children.get(0).getPlan() instanceof PhysicalDistribute) {
-            // if distinct without group by key, we prefer three or four stage distinct agg
-            // because the second phase of multi-distinct only have one instance, and it is slow generally.
-            if (agg.getGroupByExpressions().size() == 1
-                    && agg.getOutputExpressions().size() == 1) {
-                return true;
-            }
-            return false;
-        }
 
         // forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
         // TODO: this is forbid good plan after cte reuse by mistake
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 696463523f6..e0f8ef23013 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -416,6 +416,7 @@ public enum RuleType {
     TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION),
     THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
     FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
+    FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE(RuleTypeClass.IMPLEMENTATION),
     LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION),
     LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION),
     LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index edbd28677b4..556ca20e47a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -292,16 +292,23 @@ public class AggregateStrategies implements ImplementationRuleFactory {
             //         .when(agg -> agg.getDistinctArguments().size() == 1)
             //         .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
             // ),
+            RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
+                basePattern
+                    .when(agg -> agg.getDistinctArguments().size() == 1)
+                    .thenApplyMulti(ctx ->
+                            fourPhaseAggregateWithDistinctAndFullDistribute(ctx.root, ctx.connectContext)
+                    )
+            ),
             RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
-                    basePattern
-                            .when(agg -> agg.getDistinctArguments().size() == 1)
-                            .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
+                basePattern
+                    .when(agg -> agg.getDistinctArguments().size() == 1)
+                    .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
             ),
             RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
-                    basePattern
-                            .when(agg -> agg.getDistinctArguments().size() == 1)
-                            .when(agg -> agg.getGroupByExpressions().isEmpty())
-                            .thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
+                basePattern
+                    .when(agg -> agg.getDistinctArguments().size() == 1)
+                    // .when(agg -> agg.getGroupByExpressions().isEmpty())
+                    .thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
             )
         );
     }
@@ -1831,6 +1838,189 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                 .build();
     }
 
+    /**
+     * sql:
+     * select count(distinct name), sum(age) from student;
+     * <p>
+     * 4 phase plan
+     * DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), sum(age#5)], [GATHER]
+     * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), partial_sum(age)), hash distribute by name
+     *    +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name, partial_sum(age)), hash_distribute by name
+     *       +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age))
+     *          +--scan(name, age)
+     */
+    private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinctAndFullDistribute(
+            LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
+        boolean couldBanned = couldConvertToMulti(logicalAgg);
+
+        Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
+
+        Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
+                .filter(AggregateFunction::isDistinct)
+                .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
+                .filter(NamedExpression.class::isInstance)
+                .map(NamedExpression.class::cast)
+                .collect(ImmutableSet.toImmutableSet());
+
+        Set<NamedExpression> localAggGroupBySet = ImmutableSet.<NamedExpression>builder()
+                .addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
+                .addAll(distinctArguments)
+                .build();
+
+        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
+                .filter(aggregateFunction -> !aggregateFunction.isDistinct())
+                .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
+                    AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
+                    return new Alias(localAggExpr);
+                }, (oldValue, newValue) -> newValue));
+
+        List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
+                .addAll(localAggGroupBySet)
+                .addAll(nonDistinctAggFunctionToAliasPhase1.values())
+                .build();
+
+        List<Expression> localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet);
+        boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy);
+        List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
+        RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);
+
+        boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty();
+
+        // be not recommend generate an aggregate node with empty group by and empty output,
+        // so add a null int slot to group by slot and output
+        if (isGroupByEmptySelectEmpty) {
+            localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE));
+            localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
+        }
+
+        PhysicalHashAggregate<Plan> anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy,
+                localAggOutput, Optional.of(partitionExpressions), inputToBufferParam,
+                maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(),
+                requireAny, logicalAgg.child());
+
+        AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
+                nonDistinctAggFunctionToAliasPhase1.entrySet()
+                        .stream()
+                        .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
+                            AggregateFunction originFunction = kv.getKey();
+                            Alias localOutput = kv.getValue();
+                            AggregateExpression globalAggExpr = new AggregateExpression(
+                                    originFunction, bufferToBufferParam, localOutput.toSlot());
+                            return new Alias(globalAggExpr);
+                        }));
+
+        List<NamedExpression> globalAggOutput = ImmutableList.<NamedExpression>builder()
+                .addAll(localAggGroupBySet)
+                .addAll(nonDistinctAggFunctionToAliasPhase2.values())
+                .build();
+
+        // be not recommend generate an aggregate node with empty group by and empty output,
+        // so add a null int slot to group by slot and output
+        if (isGroupByEmptySelectEmpty) {
+            globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
+        }
+
+        RequireProperties requireGroupByAndDistinctHash = RequireProperties.of(
+                PhysicalProperties.createHash(localAggGroupBy, ShuffleType.REQUIRE));
+
+        //phase 2
+        PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new PhysicalHashAggregate<>(
+                localAggGroupBy, globalAggOutput, Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
+                bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
+                requireGroupByAndDistinctHash, anyLocalAgg);
+
+        // phase 3
+        AggregateParam distinctLocalParam = new AggregateParam(
+                AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 = new HashMap<>();
+        List<NamedExpression> localDistinctOutput = Lists.newArrayList();
+        for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
+            NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
+            List<AggregateFunction> needUpdateSlot = Lists.newArrayList();
+            NamedExpression outputExprPhase3 = (NamedExpression) outputExpr
+                    .rewriteDownShortCircuit(expr -> {
+                        if (expr instanceof AggregateFunction) {
+                            AggregateFunction aggregateFunction = (AggregateFunction) expr;
+                            if (aggregateFunction.isDistinct()) {
+                                Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
+                                Preconditions.checkArgument(aggChild.size() == 1
+                                                || aggregateFunction.getDistinctArguments().size() == 1,
+                                        "cannot process more than one child in aggregate distinct function: "
+                                                + aggregateFunction);
+                                AggregateFunction nonDistinct = aggregateFunction
+                                        .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
+                                AggregateExpression nonDistinctAggExpr = new AggregateExpression(nonDistinct,
+                                        distinctLocalParam, aggregateFunction.child(0));
+                                return nonDistinctAggExpr;
+                            } else {
+                                needUpdateSlot.add(aggregateFunction);
+                                Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
+                                return new AggregateExpression(aggregateFunction,
+                                        new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_BUFFER),
+                                        alias.toSlot());
+                            }
+                        }
+                        return expr;
+                    });
+            for (AggregateFunction originFunction : needUpdateSlot) {
+                nonDistinctAggFunctionToAliasPhase3.put(originFunction, (Alias) outputExprPhase3);
+            }
+            localDistinctOutput.add(outputExprPhase3);
+
+        }
+        PhysicalHashAggregate<? extends Plan> distinctLocal = new PhysicalHashAggregate<>(
+                logicalAgg.getGroupByExpressions(), localDistinctOutput, Optional.empty(),
+                distinctLocalParam, false, logicalAgg.getLogicalProperties(),
+                requireGroupByAndDistinctHash, anyLocalHashGlobalAgg);
+
+        //phase 4
+        AggregateParam distinctGlobalParam = new AggregateParam(
+                AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT, couldBanned);
+        List<NamedExpression> globalDistinctOutput = Lists.newArrayList();
+        for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
+            NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
+            NamedExpression outputExprPhase4 = (NamedExpression) outputExpr.rewriteDownShortCircuit(expr -> {
+                if (expr instanceof AggregateFunction) {
+                    AggregateFunction aggregateFunction = (AggregateFunction) expr;
+                    if (aggregateFunction.isDistinct()) {
+                        Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
+                        Preconditions.checkArgument(aggChild.size() == 1
+                                        || aggregateFunction.getDistinctArguments().size() == 1,
+                                "cannot process more than one child in aggregate distinct function: "
+                                        + aggregateFunction);
+                        AggregateFunction nonDistinct = aggregateFunction
+                                .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
+                        int idx = logicalAgg.getOutputExpressions().indexOf(outputExpr);
+                        Alias localDistinctAlias = (Alias) (localDistinctOutput.get(idx));
+                        return new AggregateExpression(nonDistinct,
+                                distinctGlobalParam, localDistinctAlias.toSlot());
+                    } else {
+                        Alias alias = nonDistinctAggFunctionToAliasPhase3.get(expr);
+                        return new AggregateExpression(aggregateFunction,
+                                new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_RESULT),
+                                alias.toSlot());
+                    }
+                }
+                return expr;
+            });
+            globalDistinctOutput.add(outputExprPhase4);
+        }
+
+        RequireProperties requireGroupByHash = RequireProperties.of(
+                PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
+        PhysicalHashAggregate<? extends Plan> distinctGlobal = new PhysicalHashAggregate<>(
+                logicalAgg.getGroupByExpressions(), globalDistinctOutput, Optional.empty(),
+                distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
+                requireGroupByHash, distinctLocal);
+
+        return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
+                .add(distinctGlobal)
+                .build();
+    }
+
     private boolean couldConvertToMulti(LogicalAggregate<? extends Plan> aggregate) {
         Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
         for (AggregateFunction func : aggregateFunctions) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to