This is an automated email from the ASF dual-hosted git repository.
kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push:
new 5c9f07840ad [refactor](nereids)make NormalizeAggregate rule more clear
and readable #28607 (#28828)
5c9f07840ad is described below
commit 5c9f07840ad9ebacf803aebac3e310e79b5f7763
Author: starocean999 <[email protected]>
AuthorDate: Fri Dec 22 11:34:04 2023 +0800
[refactor](nereids)make NormalizeAggregate rule more clear and readable
#28607 (#28828)
---
.../nereids/rules/analysis/NormalizeAggregate.java | 189 ++++++++++++---------
.../aggregate/agg_distinct_case_when.groovy | 54 ++++++
.../window_functions/test_window_fn.groovy | 6 +-
3 files changed, 160 insertions(+), 89 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index dd7f1197b11..a7eb7c7e5cc 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -41,9 +41,9 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -101,23 +101,94 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
@Override
public Rule build() {
return
logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
+ // The LogicalAggregate node may contain both window agg functions and
ordinary agg functions.
+ // For short, we call window agg functions "window-agg" and ordinary agg
functions "trivial-agg".
+ // This rule simplifies the LogicalAggregate node by:
+ // 1. pushing down some exprs from the old LogicalAggregate node to a new
child LogicalProject node,
+ // 2. creating a new LogicalAggregate with normalized group by exprs
and trivial-aggs,
+ // 3. pulling up the old LogicalAggregate's normalized output exprs to a
new parent LogicalProject node.
+ // Push down exprs:
+ // 1. all group by exprs
+ // 2. trivial-agg children that contain a subquery expr
+ // 3. trivial-agg children that contain a window expr
+ // 4. input slots of the remaining trivial-agg children
+ // 5. distinct trivial-agg children that are not a SlotReference or Literal (including subqueries)
+ // Normalize LogicalAggregate's output:
+ // 1. normalize group by exprs using the outputs of the bottom LogicalProject
+ // 2. normalize trivial-aggs using the outputs of the bottom LogicalProject
+ // 3. build the normalized agg outputs
+ // Pull up exprs:
+ // normalize all output exprs of the old LogicalAggregate to build a
parent project node; these typically include:
+ // 1. simple slots
+ // 2. aliases
+ // a. alias with no agg child
+ // b. alias with a trivial-agg child
+ // c. alias with a window-agg
- List<NamedExpression> aggregateOutput =
aggregate.getOutputExpressions();
- Set<Alias> existsAlias =
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
+ // Push down exprs:
+ // collect group by exprs
+ Set<Expression> groupingByExprs =
+ ImmutableSet.copyOf(aggregate.getGroupByExpressions());
+ // collect all trivial-aggs
+ List<NamedExpression> aggregateOutput =
aggregate.getOutputExpressions();
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o ->
o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
- // we need push down subquery exprs inside non-window and
non-distinct agg functions
- // because the distinct agg's children would be pushed down in
later process
- Set<SubqueryExpr> subqueryExprs =
ExpressionUtils.mutableCollect(aggFuncs.stream()
- .filter(aggFunc ->
!aggFunc.isDistinct()).collect(Collectors.toList()),
- SubqueryExpr.class::isInstance);
- Set<Expression> groupingByExprs =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
+ // split non-distinct agg children into two parts
+ // TRUE part 1: push down the child itself, if it contains a subquery or
window expression
+ // FALSE part 2: push down its input slots, if it DOES NOT
contain a subquery or window expression
+ Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren =
aggFuncs.stream()
+ .filter(aggFunc -> !aggFunc.isDistinct())
+ .flatMap(agg -> agg.children().stream())
+ .collect(Collectors.groupingBy(
+ child -> child.containsType(SubqueryExpr.class,
WindowExpression.class),
+ Collectors.toSet()));
+
+ // split distinct agg children into two parts
+ // TRUE part 1: push down the child itself, if it is NOT a SlotReference
or Literal
+ // FALSE part 2: push down its input slots, if it is a
SlotReference or Literal
+ Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren =
aggFuncs.stream()
+ .filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg ->
agg.children().stream())
+ .collect(Collectors.groupingBy(
+ child -> !(child instanceof SlotReference || child
instanceof Literal),
+ Collectors.toSet()));
+
+ Set<Expression> needPushSelf = Sets.union(
+ categorizedNoDistinctAggsChildren.getOrDefault(true, new
HashSet<>()),
+ categorizedDistinctAggsChildren.getOrDefault(true, new
HashSet<>()));
+ Set<Slot> needPushInputSlots =
ExpressionUtils.getInputSlotSet(Sets.union(
+ categorizedNoDistinctAggsChildren.getOrDefault(false, new
HashSet<>()),
+ categorizedDistinctAggsChildren.getOrDefault(false, new
HashSet<>())));
+
+ Set<Alias> existsAlias =
+ ExpressionUtils.mutableCollect(aggregateOutput,
Alias.class::isInstance);
+
+ // push down 3 kinds of exprs; these pushed exprs are used to
normalize the agg output later
+ // 1. group by exprs
+ // 2. trivial-agg children that must be pushed down as a whole
+ // 3. input slots of the other trivial-agg children
+ Set<Expression> allPushDownExprs =
+ Sets.union(groupingByExprs, Sets.union(needPushSelf,
needPushInputSlots));
NormalizeToSlotContext bottomSlotContext =
- NormalizeToSlotContext.buildContext(existsAlias,
Sets.union(groupingByExprs, subqueryExprs));
- Set<NamedExpression> bottomOutputs =
-
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs,
subqueryExprs));
+ NormalizeToSlotContext.buildContext(existsAlias,
allPushDownExprs);
+ Set<NamedExpression> pushedGroupByExprs =
+
bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
+ Set<NamedExpression> pushedTrivalAggChildren =
+ bottomSlotContext.pushDownToNamedExpression(needPushSelf);
+ Set<NamedExpression> pushedTrivalAggInputSlots =
+
bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
+ Set<NamedExpression> bottomProjects =
Sets.union(pushedGroupByExprs,
+ Sets.union(pushedTrivalAggChildren,
pushedTrivalAggInputSlots));
+
+ // create bottom project
+ Plan bottomPlan;
+ if (!bottomProjects.isEmpty()) {
+ bottomPlan = new
LogicalProject<>(ImmutableList.copyOf(bottomProjects),
+ aggregate.child());
+ } else {
+ bottomPlan = aggregate.child();
+ }
// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
@@ -129,89 +200,37 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a +
1)[#1])
// +-- project((a[#0] + 1)[#1])
- List<AggregateFunction> normalizedAggFuncs =
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
- Set<NamedExpression> bottomProjects =
Sets.newHashSet(bottomOutputs);
- // TODO: if we have distinct agg, we must push down its children,
- // because need use it to generate distribution enforce
- // step 1: split agg functions into 2 parts: distinct and not
distinct
- List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
- List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
- for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
- if (aggregateFunction.isDistinct()) {
- distinctAggFuncs.add(aggregateFunction);
- } else {
- nonDistinctAggFuncs.add(aggregateFunction);
- }
- }
- // step 2: if we only have one distinct agg function, we do push
down for it
- if (!distinctAggFuncs.isEmpty()) {
- // process distinct normalize and put it back to
normalizedAggFuncs
- List<AggregateFunction> newDistinctAggFuncs =
Lists.newArrayList();
- Map<Expression, Expression> replaceMap = Maps.newHashMap();
- Map<Expression, NamedExpression> aliasCache =
Maps.newHashMap();
- for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
- List<Expression> newChildren = Lists.newArrayList();
- for (Expression child : distinctAggFunc.children()) {
- if (child instanceof SlotReference || child instanceof
Literal) {
- newChildren.add(child);
- } else {
- NamedExpression alias;
- if (aliasCache.containsKey(child)) {
- alias = aliasCache.get(child);
- } else {
- alias = new Alias(child);
- aliasCache.put(child, alias);
- }
- bottomProjects.add(alias);
- newChildren.add(alias.toSlot());
- }
- }
- AggregateFunction newDistinctAggFunc =
distinctAggFunc.withChildren(newChildren);
- replaceMap.put(distinctAggFunc, newDistinctAggFunc);
- newDistinctAggFuncs.add(newDistinctAggFunc);
- }
- aggregateOutput = aggregateOutput.stream()
- .map(e -> ExpressionUtils.replace(e, replaceMap))
- .map(NamedExpression.class::cast)
- .collect(Collectors.toList());
- distinctAggFuncs = newDistinctAggFuncs;
- }
- normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
- normalizedAggFuncs.addAll(distinctAggFuncs);
- // TODO: process redundant expressions in aggregate functions
children
+
+ // normalize group by exprs by bottomProjects
+ List<Expression> normalizedGroupExprs =
+ bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
+
+ // normalize trivial-aggs by bottomProjects
+ List<AggregateFunction> normalizedAggFuncs =
+ bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
+
// build normalized agg output
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias,
normalizedAggFuncs);
- // agg output include 2 part, normalized group by slots and
normalized agg functions
+
+ // agg output includes 2 parts:
+ // pushedGroupByExprs (as slots) and normalized agg functions
List<NamedExpression> normalizedAggOutput =
ImmutableList.<NamedExpression>builder()
-
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
-
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
+
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
+ .addAll(normalizedAggFuncsToSlotContext
+ .pushDownToNamedExpression(normalizedAggFuncs))
.build();
- // add normalized agg's input slots to bottom projects
- Set<Slot> bottomProjectSlots = bottomProjects.stream()
- .map(NamedExpression::toSlot)
- .collect(Collectors.toSet());
- Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream()
- .map(Expression::getInputSlots)
- .flatMap(Set::stream)
- .filter(e -> !bottomProjectSlots.contains(e))
- .collect(Collectors.toSet());
- bottomProjects.addAll(aggInputSlots);
- // build group by exprs
- List<Expression> normalizedGroupExprs =
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
- Plan bottomPlan;
- if (!bottomProjects.isEmpty()) {
- bottomPlan = new
LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
- } else {
- bottomPlan = aggregate.child();
- }
+ // create new agg node
+ LogicalAggregate newAggregate =
+ aggregate.withNormalized(normalizedGroupExprs,
normalizedAggOutput, bottomPlan);
+ // create upper projects by normalizing all output exprs of the old
LogicalAggregate
List<NamedExpression> upperProjects =
normalizeOutput(aggregateOutput,
bottomSlotContext, normalizedAggFuncsToSlotContext);
- return new LogicalProject<>(upperProjects,
- aggregate.withNormalized(normalizedGroupExprs,
normalizedAggOutput, bottomPlan));
+ // create a parent project node
+ return new LogicalProject<>(upperProjects, newAggregate);
}).toRule(RuleType.NORMALIZE_AGGREGATE);
}
diff --git
a/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy
b/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy
new file mode 100644
index 00000000000..74caa459c15
--- /dev/null
+++ b/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("agg_distinct_case_when") {
+ sql "SET enable_nereids_planner=true"
+ sql "SET enable_fallback_to_original_planner=false"
+ sql "DROP TABLE IF EXISTS agg_test_table_t;"
+ sql """
+ CREATE TABLE `agg_test_table_t` (
+ `k1` varchar(65533) NULL,
+ `k2` text NULL,
+ `k3` text null,
+ `k4` text null
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`k1`)
+ COMMENT 'OLAP'
+ DISTRIBUTED BY HASH(`k1`) BUCKETS 10
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "is_being_synced" = "false",
+ "storage_format" = "V2",
+ "light_schema_change" = "true",
+ "disable_auto_compaction" = "false",
+ "enable_single_replica_compaction" = "false"
+ );
+ """
+
+ sql """insert into agg_test_table_t(`k1`,`k2`,`k3`)
values('20231026221524','PA','adigu1bububud');"""
+ sql """
+ select
+ count(distinct case when t.k2='PA' and
loan_date=to_date(substr(t.k1,1,8)) then t.k2 end )
+ from (
+ select substr(k1,1,8) loan_date,k3,k2,k1 from agg_test_table_t) t
+ group by
+ substr(t.k1,1,8);"""
+
+ sql "DROP TABLE IF EXISTS agg_test_table_t;"
+}
\ No newline at end of file
diff --git
a/regression-test/suites/query_p0/sql_functions/window_functions/test_window_fn.groovy
b/regression-test/suites/query_p0/sql_functions/window_functions/test_window_fn.groovy
index 22a9e798f0a..e8604debe45 100644
---
a/regression-test/suites/query_p0/sql_functions/window_functions/test_window_fn.groovy
+++
b/regression-test/suites/query_p0/sql_functions/window_functions/test_window_fn.groovy
@@ -383,10 +383,8 @@ suite("test_window_fn") {
"storage_format" = "V2"
);
"""
- test {
- sql """SELECT SUM(MAX(c1) OVER (PARTITION BY c2, c3)) FROM
test_window_in_agg;"""
- exception "errCode = 2, detailMessage = AGGREGATE clause must not
contain analytic expressions"
- }
+ sql """set enable_nereids_planner=true;"""
+ sql """SELECT SUM(MAX(c1) OVER (PARTITION BY c2, c3)) FROM
test_window_in_agg;"""
sql "DROP TABLE IF EXISTS test_window_in_agg;"
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]