This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f778c38a1 [feature](nereids) support explore 4 phase aggregation 
(#16298)
4f778c38a1 is described below

commit 4f778c38a15024ba3f7a28b69e5888d714b4ccd2
Author: minghong <[email protected]>
AuthorDate: Fri Feb 3 21:51:10 2023 +0800

    [feature](nereids) support explore 4 phase aggregation (#16298)
    
    Support 4-phase aggregation.
    Example:
    `select count(distinct k1), sum(k2) from t`
    Suppose t.k0 is the distribution key.
    
    The resulting plan is:
    ```
    Agg(DISTINCT_GLOBAL)
       |
    Exchange(Gather)
      |
    Agg(DISTINCT_LOCAL)
      |
    Agg(GLOBAL)
      |
    Exchange(hash distribute by k1)
     |
    Agg(LOCAL)
     |
    scan
    ```
    
    Limitations:
    1. Only SQL with a single distinct aggregate is supported.
    Not supported: `select count(distinct k1), count(distinct k2) from t`
    2. Only a distinct aggregate over a single column is supported.
    Not supported: `select count(distinct k1, k2) from t`
---
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../nereids/rules/rewrite/AggregateStrategies.java | 159 +++++++++++++++++++++
 .../data/nereids_syntax_p0/agg_4_phase.out         |   4 +
 .../suites/nereids_syntax_p0/agg_4_phase.groovy    |  59 ++++++++
 4 files changed, 223 insertions(+)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 56edfa23c4..b570a1c466 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -249,6 +249,7 @@ public enum RuleType {
     TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI(RuleTypeClass.IMPLEMENTATION),
     TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION),
     THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
+    FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
     LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION),
     LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION),
     LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
index 7194c89cc8..20b261d3eb 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
@@ -72,6 +72,7 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -157,6 +158,12 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 basePattern
                     .when(agg -> agg.getDistinctArguments().size() > 1 && 
!containsCountDistinctMultiExpr(agg))
                     .thenApplyMulti(ctx -> 
twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
+            ),
+            RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
+                    basePattern
+                            .when(agg -> agg.getDistinctArguments().size() == 
1)
+                            .when(agg -> agg.getGroupByExpressions().isEmpty())
+                            .thenApplyMulti(ctx -> 
fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
             )
         );
     }
@@ -1239,4 +1246,156 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
         ConnectContext connectContext = ConnectContext.get();
         return connectContext == null || 
connectContext.getSessionVariable().enableSingleDistinctColumnOpt();
     }
+
+    /**
+     * sql:
+     * select count(distinct name), sum(age) from student;
+     *
+     * 4 phase plan
+     * DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), 
sum(age#5)], [GATHER]
+     * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), 
partial_sum(age)), hash distribute by name
+     *    +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name, 
partial_sum(age)), hash_distribute by name
+     *       +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name, 
partial_sum(age))
+     *          +--scan(name, age)
+     */
+    private List<PhysicalHashAggregate<? extends Plan>> 
fourPhaseAggregateWithDistinct(
+            LogicalAggregate<? extends Plan> logicalAgg, ConnectContext 
connectContext) {
+        Set<AggregateFunction> aggregateFunctions = 
logicalAgg.getAggregateFunctions();
+
+        Set<Expression> distinctArguments = aggregateFunctions.stream()
+                .filter(aggregateExpression -> 
aggregateExpression.isDistinct())
+                .flatMap(aggregateExpression -> 
aggregateExpression.getArguments().stream())
+                .collect(ImmutableSet.toImmutableSet());
+
+        Set<NamedExpression> localAggGroupBySet = 
ImmutableSet.<NamedExpression>builder()
+                .addAll((List) logicalAgg.getGroupByExpressions())
+                .addAll(distinctArguments)
+                .build();
+
+        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER);
+
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = 
aggregateFunctions.stream()
+                .filter(aggregateFunction -> !aggregateFunction.isDistinct())
+                .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
+                    AggregateExpression localAggExpr = new 
AggregateExpression(expr, inputToBufferParam);
+                    return new Alias(localAggExpr, localAggExpr.toSql());
+                }));
+
+        List<NamedExpression> localAggOutput = 
ImmutableList.<NamedExpression>builder()
+                .addAll(localAggGroupBySet)
+                .addAll(nonDistinctAggFunctionToAliasPhase1.values())
+                .build();
+
+        List<Expression> localAggGroupBy = 
ImmutableList.copyOf(localAggGroupBySet);
+        boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, 
localAggGroupBy);
+        List<Expression> partitionExpressions = 
getHashAggregatePartitionExpressions(logicalAgg);
+        RequireProperties requireAny = 
RequireProperties.of(PhysicalProperties.ANY);
+        PhysicalHashAggregate<Plan> anyLocalAgg = new 
PhysicalHashAggregate<>(localAggGroupBy,
+                localAggOutput, Optional.of(partitionExpressions), 
inputToBufferParam,
+                maybeUsingStreamAgg, Optional.empty(), 
logicalAgg.getLogicalProperties(),
+                requireAny, logicalAgg.child());
+
+        AggregateParam bufferToBufferParam = new 
AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER);
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
+                nonDistinctAggFunctionToAliasPhase1.entrySet()
+                        .stream()
+                        .collect(ImmutableMap.toImmutableMap(kv -> 
kv.getKey(), kv -> {
+                            AggregateFunction originFunction = kv.getKey();
+                            Alias localOutput = kv.getValue();
+                            AggregateExpression globalAggExpr = new 
AggregateExpression(
+                                    originFunction, bufferToBufferParam, 
localOutput.toSlot());
+                            return new Alias(globalAggExpr, 
globalAggExpr.toSql());
+                        }));
+
+        List<NamedExpression> globalAggOutput = 
ImmutableList.<NamedExpression>builder()
+                .addAll(localAggGroupBySet)
+                .addAll(nonDistinctAggFunctionToAliasPhase2.values())
+                .build();
+
+        RequireProperties requireGather = 
RequireProperties.of(PhysicalProperties.GATHER);
+
+        RequireProperties requireDistinctHash = RequireProperties.of(
+                
PhysicalProperties.createHash(logicalAgg.getDistinctArguments(), 
ShuffleType.AGGREGATE));
+
+        //phase 2
+        PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new 
PhysicalHashAggregate<>(
+                localAggGroupBy, globalAggOutput, 
Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
+                bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
+                requireDistinctHash, anyLocalAgg);
+
+        // phase 3
+        AggregateParam distinctLocalParam = new 
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER);
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 = 
new HashMap<>();
+        List<NamedExpression> localDistinctOutput = Lists.newArrayList();
+        for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
+            NamedExpression outputExpr = 
logicalAgg.getOutputExpressions().get(i);
+            List<AggregateFunction> needUpdateSlot = Lists.newArrayList();
+            NamedExpression outputExprPhase3 = (NamedExpression) outputExpr
+                    .rewriteDownShortCircuit(expr -> {
+                        if (expr instanceof AggregateFunction) {
+                            AggregateFunction aggregateFunction = 
(AggregateFunction) expr;
+                            if (aggregateFunction.isDistinct()) {
+                                
Preconditions.checkArgument(aggregateFunction.arity() == 1);
+                                AggregateFunction nonDistinct = 
aggregateFunction
+                                        .withDistinctAndChildren(false, 
aggregateFunction.getArguments());
+                                AggregateExpression nonDistinctAggExpr = new 
AggregateExpression(nonDistinct,
+                                        distinctLocalParam, 
aggregateFunction.child(0));
+                                return nonDistinctAggExpr;
+                            } else {
+                                needUpdateSlot.add(aggregateFunction);
+                                Alias alias = 
nonDistinctAggFunctionToAliasPhase2.get(expr);
+                                return new 
AggregateExpression(aggregateFunction,
+                                        new 
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_BUFFER),
+                                        alias.toSlot());
+                            }
+                        }
+                        return expr;
+                    });
+            for (AggregateFunction originFunction : needUpdateSlot) {
+                nonDistinctAggFunctionToAliasPhase3.put(originFunction, 
(Alias) outputExprPhase3);
+            }
+            localDistinctOutput.add(outputExprPhase3);
+
+        }
+        PhysicalHashAggregate<? extends Plan> distinctLocal = new 
PhysicalHashAggregate<>(
+                logicalAgg.getGroupByExpressions(), localDistinctOutput, 
Optional.empty(),
+                distinctLocalParam, false, logicalAgg.getLogicalProperties(),
+                requireDistinctHash, anyLocalHashGlobalAgg);
+
+        //phase 4
+        AggregateParam distinctGlobalParam = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT);
+        List<NamedExpression> globalDistinctOutput = Lists.newArrayList();
+        for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
+            NamedExpression outputExpr = 
logicalAgg.getOutputExpressions().get(i);
+            NamedExpression outputExprPhase4 = (NamedExpression) 
outputExpr.rewriteDownShortCircuit(expr -> {
+                if (expr instanceof AggregateFunction) {
+                    AggregateFunction aggregateFunction = (AggregateFunction) 
expr;
+                    if (aggregateFunction.isDistinct()) {
+                        Preconditions.checkArgument(aggregateFunction.arity() 
== 1);
+                        AggregateFunction nonDistinct = aggregateFunction
+                                .withDistinctAndChildren(false, 
aggregateFunction.getArguments());
+                        int idx = 
logicalAgg.getOutputExpressions().indexOf(outputExpr);
+                        Alias localDistinctAlias = (Alias) 
(localDistinctOutput.get(idx));
+                        return new AggregateExpression(nonDistinct,
+                                distinctGlobalParam, 
localDistinctAlias.toSlot());
+                    } else {
+                        Alias alias = 
nonDistinctAggFunctionToAliasPhase3.get(expr);
+                        return new AggregateExpression(aggregateFunction,
+                                new AggregateParam(AggPhase.DISTINCT_LOCAL, 
AggMode.BUFFER_TO_RESULT),
+                                alias.toSlot());
+                    }
+                }
+                return expr;
+            });
+            globalDistinctOutput.add(outputExprPhase4);
+        }
+        PhysicalHashAggregate<? extends Plan> distinctGlobal = new 
PhysicalHashAggregate<>(
+                logicalAgg.getGroupByExpressions(), globalDistinctOutput, 
Optional.empty(),
+                distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
+                requireGather, distinctLocal);
+
+        return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
+                .add(distinctGlobal)
+                .build();
+    }
 }
diff --git a/regression-test/data/nereids_syntax_p0/agg_4_phase.out 
b/regression-test/data/nereids_syntax_p0/agg_4_phase.out
new file mode 100644
index 0000000000..5c5dde6f85
--- /dev/null
+++ b/regression-test/data/nereids_syntax_p0/agg_4_phase.out
@@ -0,0 +1,4 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !4phase --
+3      160
+
diff --git a/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy 
b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
new file mode 100644
index 0000000000..d2d48e3e08
--- /dev/null
+++ b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("agg_4_phase") {
+    sql "SET enable_nereids_planner=true"
+    sql "set enable_fallback_to_original_planner=false"
+    sql "drop table if exists agg_4_phase_tbl"
+    sql """
+        CREATE TABLE agg_4_phase_tbl (
+            id int(11) NULL,
+            gender int,
+            name varchar(20),
+            age int
+        ) ENGINE=OLAP
+        DUPLICATE KEY(id)
+        COMMENT 'OLAP'
+        DISTRIBUTED BY HASH(id) BUCKETS 2
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1",
+            "in_memory" = "false",
+            "storage_format" = "V2",
+            "light_schema_change" = "true",
+            "disable_auto_compaction" = "false"
+        ); 
+        """
+    sql """
+        insert into agg_4_phase_tbl values 
+        (0, 0, "aa", 10), (1, 1, "bb",20), (2, 2, "cc", 30), (1, 1, "bb",20),
+        (0, 0, "aa", 10), (1, 1, "bb",20), (2, 2, "cc", 30), (1, 1, "bb",20);
+    """
+    def test_sql = """
+        select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT,TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
 
+            count(distinct name), sum(age) 
+        from agg_4_phase_tbl;
+        """
+    explain{
+        sql(test_sql)
+        contains "6:VAGGREGATE (merge finalize)"
+        contains "5:VEXCHANGE"
+        contains "4:VAGGREGATE (update serialize)"
+        contains "3:VAGGREGATE (merge serialize)"
+        contains "1:VAGGREGATE (update serialize)"
+    }
+    qt_4phase (test_sql)
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to