This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 4f778c38a1 [feature](nereids) support explore 4 phase aggregation
(#16298)
4f778c38a1 is described below
commit 4f778c38a15024ba3f7a28b69e5888d714b4ccd2
Author: minghong <[email protected]>
AuthorDate: Fri Feb 3 21:51:10 2023 +0800
[feature](nereids) support explore 4 phase aggregation (#16298)
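Support 4-phase aggregation.
Example:
`select count(distinct k1), sum(k2) from t`
Suppose t.k0 is the distribution key, so the rows must be reshuffled by the distinct column k1.
The planner can then produce the following plan: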
```
Agg(DISTINCT_GLOBAL)
|
Exchange(Gather)
|
Agg(DISTINCT_LOCAL)
|
Agg(GLOBAL)
|
Exchange(hash distribute by k1)
|
Agg(LOCAL)
|
scan
```
Limitations:
1. Only one distinct column is supported across all aggregate functions.
   Not supported: `select count(distinct k1), count(distinct k2) from t`
2. A distinct aggregate may take only one column.
   Not supported: `select count(distinct k1, k2) from t`
3. Only queries without GROUP BY are supported.
---
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../nereids/rules/rewrite/AggregateStrategies.java | 159 +++++++++++++++++++++
.../data/nereids_syntax_p0/agg_4_phase.out | 4 +
.../suites/nereids_syntax_p0/agg_4_phase.groovy | 59 ++++++++
4 files changed, 223 insertions(+)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 56edfa23c4..b570a1c466 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -249,6 +249,7 @@ public enum RuleType {
TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI(RuleTypeClass.IMPLEMENTATION),
TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION),
THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
+ FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION),
LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION),
LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
index 7194c89cc8..20b261d3eb 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
@@ -72,6 +72,7 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -157,6 +158,12 @@ public class AggregateStrategies implements
ImplementationRuleFactory {
basePattern
.when(agg -> agg.getDistinctArguments().size() > 1 &&
!containsCountDistinctMultiExpr(agg))
.thenApplyMulti(ctx ->
twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
+ ),
+ RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
+ basePattern
+ .when(agg -> agg.getDistinctArguments().size() ==
1)
+ .when(agg -> agg.getGroupByExpressions().isEmpty())
+ .thenApplyMulti(ctx ->
fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
)
);
}
@@ -1239,4 +1246,156 @@ public class AggregateStrategies implements
ImplementationRuleFactory {
ConnectContext connectContext = ConnectContext.get();
return connectContext == null ||
connectContext.getSessionVariable().enableSingleDistinctColumnOpt();
}
+
+    /**
+     * Builds a 4-phase physical aggregate for a logical aggregate with one distinct argument and no GROUP BY.
+     *
+     * <p>Example sql: {@code select count(distinct name), sum(age) from student;}
+     *
+     * <pre>
+     * DISTINCT_GLOBAL, BUFFER_TO_RESULT, groupBy(), output(count(name), sum(age)), require [GATHER]
+     * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy(), output(count(name), partial_sum(age)), require [hash by name]
+     *    +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name, partial_sum(age)), require [hash by name]
+     *       +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age)), require [ANY]
+     *          +--scan(name, age)
+     * </pre>
+     *
+     * <ol>
+     *   <li>LOCAL: groups by the GROUP BY keys plus the distinct argument and computes buffers for the
+     *       non-distinct functions.</li>
+     *   <li>GLOBAL: requires hash distribution on the distinct argument and merges the phase-1 buffers by the
+     *       same keys, so each distinct value ends up in exactly one row.</li>
+     *   <li>DISTINCT_LOCAL: on the same distribution, feeds those de-duplicated values into the non-distinct
+     *       form of the distinct function and merges the phase-2 buffers of the other functions.</li>
+     *   <li>DISTINCT_GLOBAL: requires GATHER and finalizes every function.</li>
+     * </ol>
+     *
+     * <p>Phases 3 and 4 exchange aggregate buffers through the output slot of each logical output, so every
+     * output that contains an aggregate function must be an {@link Alias} directly wrapping a single
+     * {@link AggregateFunction}. Otherwise a computation on top of the aggregate would be evaluated on an
+     * intermediate buffer; for such outputs no plan is produced.
+     *
+     * @param logicalAgg the logical aggregate; the FOUR_PHASE_AGGREGATE_WITH_DISTINCT rule only calls this
+     *                   when it has exactly one distinct argument and no GROUP BY keys
+     * @param connectContext session context, passed to {@code maybeUsingStreamAgg} for phase 1
+     * @return a single-element list holding the phase-4 aggregate (phases 3, 2 and 1 are its descendants),
+     *         or an empty list when an output expression has an unsupported shape
+     */
+    private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinct(
+            LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
+        Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
+
+        Set<Expression> distinctArguments = aggregateFunctions.stream()
+                .filter(aggregateExpression -> aggregateExpression.isDistinct())
+                .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
+                .collect(ImmutableSet.toImmutableSet());
+
+        // The raw (List) cast makes the rest of this builder chain raw, which is what allows the distinct
+        // arguments (typed Expression) into a Set<NamedExpression>.
+        // NOTE(review): assumes the distinct arguments are NamedExpressions (slots) at runtime - confirm.
+        Set<NamedExpression> localAggGroupBySet = ImmutableSet.<NamedExpression>builder()
+                .addAll((List) logicalAgg.getGroupByExpressions())
+                .addAll(distinctArguments)
+                .build();
+
+        // phase 1: local aggregate grouped by GROUP BY keys + distinct argument, partial non-distinct aggs
+        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
+
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
+                .filter(aggregateFunction -> !aggregateFunction.isDistinct())
+                .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
+                    AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
+                    return new Alias(localAggExpr, localAggExpr.toSql());
+                }));
+
+        List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
+                .addAll(localAggGroupBySet)
+                .addAll(nonDistinctAggFunctionToAliasPhase1.values())
+                .build();
+
+        List<Expression> localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet);
+        boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy);
+        List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
+        RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);
+        PhysicalHashAggregate<Plan> anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy,
+                localAggOutput, Optional.of(partitionExpressions), inputToBufferParam,
+                maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(),
+                requireAny, logicalAgg.child());
+
+        // phase 2: hash by the distinct argument, then merge the phase-1 buffers with the same keys
+        AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER);
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
+                nonDistinctAggFunctionToAliasPhase1.entrySet()
+                        .stream()
+                        .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> {
+                            AggregateFunction originFunction = kv.getKey();
+                            Alias localOutput = kv.getValue();
+                            AggregateExpression globalAggExpr = new AggregateExpression(
+                                    originFunction, bufferToBufferParam, localOutput.toSlot());
+                            return new Alias(globalAggExpr, globalAggExpr.toSql());
+                        }));
+
+        List<NamedExpression> globalAggOutput = ImmutableList.<NamedExpression>builder()
+                .addAll(localAggGroupBySet)
+                .addAll(nonDistinctAggFunctionToAliasPhase2.values())
+                .build();
+
+        RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
+
+        RequireProperties requireDistinctHash = RequireProperties.of(
+                PhysicalProperties.createHash(logicalAgg.getDistinctArguments(), ShuffleType.AGGREGATE));
+
+        PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new PhysicalHashAggregate<>(
+                localAggGroupBy, globalAggOutput,
+                Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
+                bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
+                requireDistinctHash, anyLocalAgg);
+
+        // phase 3: feed the de-duplicated distinct values into the non-distinct form of the distinct
+        // function and merge the phase-2 buffers of the other functions
+        AggregateParam distinctLocalParam = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER);
+        AggregateParam distinctLocalMergeParam =
+                new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_BUFFER);
+        Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 = new HashMap<>();
+        List<NamedExpression> localDistinctOutput = Lists.newArrayList();
+        for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
+            NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
+            List<AggregateFunction> rewrittenFunctions = Lists.newArrayList();
+            NamedExpression outputExprPhase3 = (NamedExpression) outputExpr
+                    .rewriteDownShortCircuit(expr -> {
+                        if (expr instanceof AggregateFunction) {
+                            AggregateFunction aggregateFunction = (AggregateFunction) expr;
+                            rewrittenFunctions.add(aggregateFunction);
+                            if (aggregateFunction.isDistinct()) {
+                                Preconditions.checkArgument(aggregateFunction.arity() == 1);
+                                AggregateFunction nonDistinct = aggregateFunction
+                                        .withDistinctAndChildren(false, aggregateFunction.getArguments());
+                                return new AggregateExpression(nonDistinct,
+                                        distinctLocalParam, aggregateFunction.child(0));
+                            } else {
+                                Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
+                                return new AggregateExpression(aggregateFunction,
+                                        distinctLocalMergeParam, alias.toSlot());
+                            }
+                        }
+                        return expr;
+                    });
+            if (!rewrittenFunctions.isEmpty()
+                    && !(outputExpr instanceof Alias && outputExpr.child(0) instanceof AggregateFunction)) {
+                // Anything computed on top of the aggregate would be applied to an intermediate buffer here
+                // and again in phase 4, and the (Alias) casts below would fail: do not offer this plan.
+                return ImmutableList.of();
+            }
+            for (AggregateFunction originFunction : rewrittenFunctions) {
+                if (!originFunction.isDistinct()) {
+                    nonDistinctAggFunctionToAliasPhase3.put(originFunction, (Alias) outputExprPhase3);
+                }
+            }
+            localDistinctOutput.add(outputExprPhase3);
+        }
+        PhysicalHashAggregate<? extends Plan> distinctLocal = new PhysicalHashAggregate<>(
+                logicalAgg.getGroupByExpressions(), localDistinctOutput, Optional.empty(),
+                distinctLocalParam, false, logicalAgg.getLogicalProperties(),
+                requireDistinctHash, anyLocalHashGlobalAgg);
+
+        // phase 4: gather to a single instance and finalize every aggregate function
+        AggregateParam distinctGlobalParam = new AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT);
+        List<NamedExpression> globalDistinctOutput = Lists.newArrayList();
+        for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
+            NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
+            // phase 3 added exactly one output per logical output, in the same order
+            int outputIndex = i;
+            NamedExpression outputExprPhase4 = (NamedExpression) outputExpr.rewriteDownShortCircuit(expr -> {
+                if (expr instanceof AggregateFunction) {
+                    AggregateFunction aggregateFunction = (AggregateFunction) expr;
+                    if (aggregateFunction.isDistinct()) {
+                        Preconditions.checkArgument(aggregateFunction.arity() == 1);
+                        AggregateFunction nonDistinct = aggregateFunction
+                                .withDistinctAndChildren(false, aggregateFunction.getArguments());
+                        Alias localDistinctAlias = (Alias) localDistinctOutput.get(outputIndex);
+                        return new AggregateExpression(nonDistinct,
+                                distinctGlobalParam, localDistinctAlias.toSlot());
+                    } else {
+                        // finalized by the DISTINCT_GLOBAL node, so tag it with that node's phase
+                        Alias alias = nonDistinctAggFunctionToAliasPhase3.get(expr);
+                        return new AggregateExpression(aggregateFunction,
+                                distinctGlobalParam, alias.toSlot());
+                    }
+                }
+                return expr;
+            });
+            globalDistinctOutput.add(outputExprPhase4);
+        }
+        PhysicalHashAggregate<? extends Plan> distinctGlobal = new PhysicalHashAggregate<>(
+                logicalAgg.getGroupByExpressions(), globalDistinctOutput, Optional.empty(),
+                distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
+                requireGather, distinctLocal);
+
+        return ImmutableList.of(distinctGlobal);
+    }
}
diff --git a/regression-test/data/nereids_syntax_p0/agg_4_phase.out
b/regression-test/data/nereids_syntax_p0/agg_4_phase.out
new file mode 100644
index 0000000000..5c5dde6f85
--- /dev/null
+++ b/regression-test/data/nereids_syntax_p0/agg_4_phase.out
@@ -0,0 +1,4 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !4phase --
+3 160
+
diff --git a/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
new file mode 100644
index 0000000000..d2d48e3e08
--- /dev/null
+++ b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Checks that the Nereids planner can build a 4-phase aggregate for one distinct column without
+// GROUP BY, and that the query result is correct.
+suite("agg_4_phase") {
+    def tableName = "agg_4_phase_tbl"
+
+    // run on the Nereids planner only, with no fallback to the original planner
+    sql "SET enable_nereids_planner=true"
+    sql "set enable_fallback_to_original_planner=false"
+
+    sql "drop table if exists ${tableName}"
+    sql """
+        CREATE TABLE ${tableName} (
+            id int(11) NULL,
+            gender int,
+            name varchar(20),
+            age int
+        ) ENGINE=OLAP
+        DUPLICATE KEY(id)
+        COMMENT 'OLAP'
+        DISTRIBUTED BY HASH(id) BUCKETS 2
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1",
+            "in_memory" = "false",
+            "storage_format" = "V2",
+            "light_schema_change" = "true",
+            "disable_auto_compaction" = "false"
+        );
+    """
+
+    // 3 distinct names (aa, bb, cc); sum(age) = 2 * (10 + 20 + 30 + 20) = 160
+    sql """
+        insert into ${tableName} values
+        (0, 0, "aa", 10), (1, 1, "bb",20), (2, 2, "cc", 30), (1, 1, "bb",20),
+        (0, 0, "aa", 10), (1, 1, "bb",20), (2, 2, "cc", 30), (1, 1, "bb",20);
+    """
+
+    // the 2- and 3-phase distinct strategies are disabled so the plan is expected to come
+    // from the 4-phase strategy
+    def fourPhaseSql = """
+        select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT,TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
+        count(distinct name), sum(age)
+        from ${tableName};
+    """
+
+    explain {
+        sql(fourPhaseSql)
+        // presumably: 6 = DISTINCT_GLOBAL, 5 = gather exchange, 4 = DISTINCT_LOCAL,
+        // 3 = GLOBAL, 1 = LOCAL
+        contains "6:VAGGREGATE (merge finalize)"
+        contains "5:VEXCHANGE"
+        contains "4:VAGGREGATE (update serialize)"
+        contains "3:VAGGREGATE (merge serialize)"
+        contains "1:VAGGREGATE (update serialize)"
+    }
+    qt_4phase(fourPhaseSql)
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]