This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 5c9f07840ad [refactor](nereids)make NormalizeAggregate rule more clear and readable #28607 (#28828)
5c9f07840ad is described below

commit 5c9f07840ad9ebacf803aebac3e310e79b5f7763
Author: starocean999 <[email protected]>
AuthorDate: Fri Dec 22 11:34:04 2023 +0800

    [refactor](nereids)make NormalizeAggregate rule more clear and readable #28607 (#28828)
---
 .../nereids/rules/analysis/NormalizeAggregate.java | 189 ++++++++++++---------
 .../aggregate/agg_distinct_case_when.groovy        |  54 ++++++
 .../window_functions/test_window_fn.groovy         |   6 +-
 3 files changed, 160 insertions(+), 89 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index dd7f1197b11..a7eb7c7e5cc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -41,9 +41,9 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -101,23 +101,94 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
     @Override
     public Rule build() {
         return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
+            // The LogicalAggregate node may contain window agg functions and usual agg functions
+            // we call window agg functions as window-agg and usual agg functions as trival-agg for short
+            // This rule simplify LogicalAggregate node by:
+            // 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
+            // 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
+            // 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
+            // Push down exprs:
+            // 1. all group by exprs
+            // 2. child contains subquery expr in trival-agg
+            // 3. child contains window expr in trival-agg
+            // 4. all input slots of trival-agg
+            // 5. expr(including subquery) in distinct trival-agg
+            // Normalize LogicalAggregate's output.
+            // 1. normalize group by exprs by outputs of bottom LogicalProject
+            // 2. normalize trival-aggs by outputs of bottom LogicalProject
+            // 3. build normalized agg outputs
+            // Pull up exprs:
+            // normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
+            // 1. simple slots
+            // 2. aliases
+            //    a. alias with no aggs child
+            //    b. alias with trival-agg child
+            //    c. alias with window-agg
 
-            List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
-            Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
+            // Push down exprs:
+            // collect group by exprs
+            Set<Expression> groupingByExprs =
+                    ImmutableSet.copyOf(aggregate.getGroupByExpressions());
 
+            // collect all trival-agg
+            List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
             List<AggregateFunction> aggFuncs = Lists.newArrayList();
             aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
 
-            // we need push down subquery exprs inside non-window and non-distinct agg functions
-            // because the distinct agg's children would be pushed down in later process
-            Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream()
-                    .filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()),
-                    SubqueryExpr.class::isInstance);
-            Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
+            // split non-distinct agg child as two part
+            // TRUE part 1: need push down itself, if it contains subqury or window expression
+            // FALSE part 2: need push down its input slots, if it DOES NOT contain subqury or window expression
+            Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
+                    .filter(aggFunc -> !aggFunc.isDistinct())
+                    .flatMap(agg -> agg.children().stream())
+                    .collect(Collectors.groupingBy(
+                            child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
+                            Collectors.toSet()));
+
+            // split distinct agg child as two parts
+            // TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
+            // FALSE part 2: need push down its input slots, if it is SlotReference or Literal
+            Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
+                    .filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
+                    .collect(Collectors.groupingBy(
+                            child -> !(child instanceof SlotReference || child instanceof Literal),
+                            Collectors.toSet()));
+
+            Set<Expression> needPushSelf = Sets.union(
+                    categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
+                    categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
+            Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
+                    categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
+                    categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
+
+            Set<Alias> existsAlias =
+                    ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
+
+            // push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
+            // 1. group by exprs
+            // 2. trivalAgg children
+            // 3. trivalAgg input slots
+            Set<Expression> allPushDownExprs =
+                    Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
             NormalizeToSlotContext bottomSlotContext =
-                    NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
-            Set<NamedExpression> bottomOutputs =
-                    bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));
+                    NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
+            Set<NamedExpression> pushedGroupByExprs =
+                    bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
+            Set<NamedExpression> pushedTrivalAggChildren =
+                    bottomSlotContext.pushDownToNamedExpression(needPushSelf);
+            Set<NamedExpression> pushedTrivalAggInputSlots =
+                    bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
+            Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
+                    Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));
+
+            // create bottom project
+            Plan bottomPlan;
+            if (!bottomProjects.isEmpty()) {
+                bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
+                        aggregate.child());
+            } else {
+                bottomPlan = aggregate.child();
+            }
 
             // use group by context to normalize agg functions to process
             //   sql like: select sum(a + 1) from t group by a + 1
@@ -129,89 +200,37 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
             // after normalize:
             // agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
             // +-- project((a[#0] + 1)[#1])
-            List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
-            Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
-            // TODO: if we have distinct agg, we must push down its children,
-            //   because need use it to generate distribution enforce
-            // step 1: split agg functions into 2 parts: distinct and not distinct
-            List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
-            List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
-            for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
-                if (aggregateFunction.isDistinct()) {
-                    distinctAggFuncs.add(aggregateFunction);
-                } else {
-                    nonDistinctAggFuncs.add(aggregateFunction);
-                }
-            }
-            // step 2: if we only have one distinct agg function, we do push down for it
-            if (!distinctAggFuncs.isEmpty()) {
-                // process distinct normalize and put it back to normalizedAggFuncs
-                List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList();
-                Map<Expression, Expression> replaceMap = Maps.newHashMap();
-                Map<Expression, NamedExpression> aliasCache = Maps.newHashMap();
-                for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
-                    List<Expression> newChildren = Lists.newArrayList();
-                    for (Expression child : distinctAggFunc.children()) {
-                        if (child instanceof SlotReference || child instanceof Literal) {
-                            newChildren.add(child);
-                        } else {
-                            NamedExpression alias;
-                            if (aliasCache.containsKey(child)) {
-                                alias = aliasCache.get(child);
-                            } else {
-                                alias = new Alias(child);
-                                aliasCache.put(child, alias);
-                            }
-                            bottomProjects.add(alias);
-                            newChildren.add(alias.toSlot());
-                        }
-                    }
-                    AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren);
-                    replaceMap.put(distinctAggFunc, newDistinctAggFunc);
-                    newDistinctAggFuncs.add(newDistinctAggFunc);
-                }
-                aggregateOutput = aggregateOutput.stream()
-                        .map(e -> ExpressionUtils.replace(e, replaceMap))
-                        .map(NamedExpression.class::cast)
-                        .collect(Collectors.toList());
-                distinctAggFuncs = newDistinctAggFuncs;
-            }
-            normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
-            normalizedAggFuncs.addAll(distinctAggFuncs);
-            // TODO: process redundant expressions in aggregate functions children
+
+            // normalize group by exprs by bottomProjects
+            List<Expression> normalizedGroupExprs =
+                    bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
+
+            // normalize trival-aggs by bottomProjects
+            List<AggregateFunction> normalizedAggFuncs =
+                    bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
+
             // build normalized agg output
             NormalizeToSlotContext normalizedAggFuncsToSlotContext =
                     NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
-            // agg output include 2 part, normalized group by slots and normalized agg functions
+
+            // agg output include 2 parts
+            // pushedGroupByExprs and normalized agg functions
             List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
-                    .addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
-                    .addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
+                    .addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
+                    .addAll(normalizedAggFuncsToSlotContext
+                            .pushDownToNamedExpression(normalizedAggFuncs))
                     .build();
-            // add normalized agg's input slots to bottom projects
-            Set<Slot> bottomProjectSlots = bottomProjects.stream()
-                    .map(NamedExpression::toSlot)
-                    .collect(Collectors.toSet());
-            Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream()
-                    .map(Expression::getInputSlots)
-                    .flatMap(Set::stream)
-                    .filter(e -> !bottomProjectSlots.contains(e))
-                    .collect(Collectors.toSet());
-            bottomProjects.addAll(aggInputSlots);
-            // build group by exprs
-            List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
 
-            Plan bottomPlan;
-            if (!bottomProjects.isEmpty()) {
-                bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
-            } else {
-                bottomPlan = aggregate.child();
-            }
+            // create new agg node
+            LogicalAggregate newAggregate =
+                    aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
 
+            // create upper projects by normalize all output exprs in old LogicalAggregate
             List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
                     bottomSlotContext, normalizedAggFuncsToSlotContext);
 
-            return new LogicalProject<>(upperProjects,
-                    aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
+            // create a parent project node
+            return new LogicalProject<>(upperProjects, newAggregate);
         }).toRule(RuleType.NORMALIZE_AGGREGATE);
     }
 
diff --git a/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy b/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy
new file mode 100644
index 00000000000..74caa459c15
--- /dev/null
+++ b/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("agg_distinct_case_when") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+    sql "DROP TABLE IF EXISTS agg_test_table_t;"
+    sql """
+        CREATE TABLE `agg_test_table_t` (
+        `k1` varchar(65533) NULL,
+        `k2` text NULL,
+        `k3` text null,
+        `k4` text null
+        ) ENGINE=OLAP
+        DUPLICATE KEY(`k1`)
+        COMMENT 'OLAP'
+        DISTRIBUTED BY HASH(`k1`) BUCKETS 10
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1",
+        "is_being_synced" = "false",
+        "storage_format" = "V2",
+        "light_schema_change" = "true",
+        "disable_auto_compaction" = "false",
+        "enable_single_replica_compaction" = "false"
+        );
+    """
+
+    sql """insert into agg_test_table_t(`k1`,`k2`,`k3`) values('20231026221524','PA','adigu1bububud');"""
+    sql """
+        select 
+        count(distinct case when t.k2='PA' and loan_date=to_date(substr(t.k1,1,8)) then t.k2 end )
+        from (
+        select substr(k1,1,8) loan_date,k3,k2,k1 from agg_test_table_t) t
+        group by
+        substr(t.k1,1,8);"""
+
+    sql "DROP TABLE IF EXISTS agg_test_table_t;"
+}
\ No newline at end of file
diff --git a/regression-test/suites/query_p0/sql_functions/window_functions/test_window_fn.groovy b/regression-test/suites/query_p0/sql_functions/window_functions/test_window_fn.groovy
index 22a9e798f0a..e8604debe45 100644
--- a/regression-test/suites/query_p0/sql_functions/window_functions/test_window_fn.groovy
+++ b/regression-test/suites/query_p0/sql_functions/window_functions/test_window_fn.groovy
@@ -383,10 +383,8 @@ suite("test_window_fn") {
         "storage_format" = "V2"
         );
         """
-    test {
-        sql """SELECT SUM(MAX(c1) OVER (PARTITION BY c2, c3)) FROM  test_window_in_agg;"""
-        exception "errCode = 2, detailMessage = AGGREGATE clause must not contain analytic expressions"
-    }
+    sql """set enable_nereids_planner=true;"""
+    sql """SELECT SUM(MAX(c1) OVER (PARTITION BY c2, c3)) FROM  test_window_in_agg;"""
     sql "DROP TABLE IF EXISTS test_window_in_agg;"
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to