This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new a9027169577 [enhancement](nereids) only push having as agg's parent if having just use slots from agg's output (#32847)
a9027169577 is described below

commit a9027169577854d18aa1476917bb1a7183a5e3e3
Author: starocean999 <[email protected]>
AuthorDate: Tue Mar 26 19:28:08 2024 +0800

    [enhancement](nereids) only push having as agg's parent if having just use slots from agg's output (#32847)
    
    pick from master #32414
---
 .../nereids/rules/analysis/NormalizeAggregate.java |  50 ++++++++--
 .../nereids/rules/analysis/AnalyzeCTETest.java     |   2 +-
 .../rules/analysis/FillUpMissingSlotsTest.java     | 101 +++++++++------------
 .../nereids_p0/aggregate/agg_window_project.out    |   4 +
 .../nereids_p0/aggregate/agg_error_msg.groovy      |  50 ++++++++++
 .../nereids_p0/aggregate/agg_window_project.groovy |   2 +
 6 files changed, 146 insertions(+), 63 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index 0176064a547..3f5f749e2fc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
 import 
org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
 import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
 import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
@@ -47,6 +48,7 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -252,13 +254,38 @@ public class NormalizeAggregate implements 
RewriteRuleFactory, NormalizeToSlot {
 
         // create a parent project node
         LogicalProject<Plan> project = new LogicalProject(upperProjects, 
newAggregate);
+        // verify project used slots are all coming from agg's output
+        List<Slot> slots = collectAllUsedSlots(upperProjects);
+        if (!slots.isEmpty()) {
+            Set<ExprId> aggOutputExprIds = new HashSet<>(slots.size());
+            for (NamedExpression expression : normalizedAggOutput) {
+                aggOutputExprIds.add(expression.getExprId());
+            }
+            List<Slot> errorSlots = new ArrayList<>(slots.size());
+            for (Slot slot : slots) {
+                if (!aggOutputExprIds.contains(slot.getExprId())) {
+                    errorSlots.add(slot);
+                }
+            }
+            if (!errorSlots.isEmpty()) {
+                throw new AnalysisException(String.format("%s not in 
aggregate's output", errorSlots
+                        
.stream().map(NamedExpression::getName).collect(Collectors.joining(", "))));
+            }
+        }
         if (having != null) {
-            if (upperProjects.stream().anyMatch(expr -> 
expr.anyMatch(WindowExpression.class::isInstance))) {
-                // when project contains window functions, in order to get the 
correct result
-                // push having through project to make it the parent node of 
logicalAgg
-                return project.withChildren(ImmutableList.of(new LogicalHaving(
-                                        
ExpressionUtils.replace(having.getConjuncts(), project.getAliasToProducer()),
-                                        project.child())));
+            Set<Slot> havingUsedSlots = 
ExpressionUtils.getInputSlotSet(having.getExpressions());
+            Set<ExprId> havingUsedExprIds = new 
HashSet<>(havingUsedSlots.size());
+            for (Slot slot : havingUsedSlots) {
+                havingUsedExprIds.add(slot.getExprId());
+            }
+            Set<ExprId> aggOutputExprIds = newAggregate.getOutputExprIdSet();
+            if (aggOutputExprIds.containsAll(havingUsedExprIds)) {
+                // when having just use output slots from agg, we push down 
having as parent of agg
+                return project.withChildren(ImmutableList.of(
+                        new LogicalHaving<>(
+                                ExpressionUtils.replace(having.getConjuncts(), 
project.getAliasToProducer()),
+                                project.child()
+                        )));
             } else {
                 return (LogicalPlan) having.withChildren(project);
             }
@@ -293,4 +320,15 @@ public class NormalizeAggregate implements 
RewriteRuleFactory, NormalizeToSlot {
         }
         return builder.build();
     }
+
+    private List<Slot> collectAllUsedSlots(List<NamedExpression> expressions) {
+        Set<Slot> inputSlots = ExpressionUtils.getInputSlotSet(expressions);
+        List<SubqueryExpr> subqueries = 
ExpressionUtils.collectAll(expressions, SubqueryExpr.class::isInstance);
+        List<Slot> slots = new ArrayList<>(inputSlots.size() + 
subqueries.size());
+        for (SubqueryExpr subqueryExpr : subqueries) {
+            slots.addAll(subqueryExpr.getCorrelateSlots());
+        }
+        slots.addAll(ExpressionUtils.getInputSlotSet(expressions));
+        return slots;
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java
index ef5a32e2d3b..522f198e3ff 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java
@@ -140,7 +140,7 @@ public class AnalyzeCTETest extends TestWithFeService 
implements MemoPatternMatc
                         logicalFilter(
                                 logicalProject(
                                         logicalJoin(
-                                                
logicalProject(logicalAggregate()),
+                                                logicalAggregate(),
                                                 logicalProject()
                                         )
                                 )
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java
index d6c71ee7759..534833754bd 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java
@@ -87,10 +87,9 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
         PlanChecker.from(connectContext).analyze(sql)
                 .matches(
                         logicalFilter(
-                                logicalProject(
-                                        logicalAggregate(
-                                                
logicalProject(logicalOlapScan())
-                                        
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))));
+                                    logicalAggregate(
+                                            logicalProject(logicalOlapScan())
+                                    
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))));
 
         sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
         SlotReference value = new SlotReference(new ExprId(3), "value", 
TinyIntType.INSTANCE, true,
@@ -134,10 +133,9 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        
logicalProject(logicalOlapScan())
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
+                                            logicalAggregate(
+                                                    
logicalProject(logicalOlapScan())
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
                                 ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))
                         ).when(FieldChecker.check("projects", 
Lists.newArrayList(sumA2.toSlot()))));
     }
@@ -158,10 +156,9 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        
logicalProject(logicalOlapScan())
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
+                                            logicalAggregate(
+                                                    
logicalProject(logicalOlapScan())
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
                                 ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))
                         ).when(FieldChecker.check("projects", 
Lists.newArrayList(a1.toSlot()))));
 
@@ -171,13 +168,12 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        logicalProject(
-                                                                
logicalOlapScan()
-                                                        )
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
-                                ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
+                                            logicalAggregate(
+                                                    logicalProject(
+                                                            logicalOlapScan()
+                                                    )
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
+                                .when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
 
         sql = "SELECT a1, sum(a2) as value FROM t1 GROUP BY a1 HAVING sum(a2) 
> 0";
         a1 = new SlotReference(
@@ -193,22 +189,20 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        logicalProject(
-                                                                
logicalOlapScan())
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
+                                            logicalAggregate(
+                                                    logicalProject(
+                                                            logicalOlapScan())
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
                                 ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));
 
         sql = "SELECT a1, sum(a2) as value FROM t1 GROUP BY a1 HAVING value > 
0";
         PlanChecker.from(connectContext).analyze(sql)
                 .matches(
                         logicalFilter(
-                                logicalProject(
-                                        logicalAggregate(
-                                                logicalProject(
-                                                        logicalOlapScan())
-                                        
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
+                                    logicalAggregate(
+                                            logicalProject(
+                                                    logicalOlapScan())
+                                    
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
                         ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))));
 
         sql = "SELECT a1, sum(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
@@ -230,10 +224,9 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        
logicalProject(logicalOlapScan())
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, 
minPK))))
+                                            logicalAggregate(
+                                                    
logicalProject(logicalOlapScan())
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, 
minPK)))
                                 ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))))
                         ).when(FieldChecker.check("projects", 
Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
 
@@ -243,10 +236,9 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        
logicalProject(logicalOlapScan())
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, 
sumA1A2))))
+                                            logicalAggregate(
+                                                    
logicalProject(logicalOlapScan())
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
                                 ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))));
 
         sql = "SELECT a1, sum(a1 + a2) FROM t1 GROUP BY a1 HAVING sum(a1 + a2 
+ 3) > 0";
@@ -256,10 +248,9 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        
logicalProject(logicalOlapScan())
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, 
sumA1A23))))
+                                            logicalAggregate(
+                                                    
logicalProject(logicalOlapScan())
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, 
sumA1A23)))
                                 ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))))
                         ).when(FieldChecker.check("projects", 
Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
 
@@ -269,10 +260,9 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        
logicalProject(logicalOlapScan())
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, 
countStar))))
+                                            logicalAggregate(
+                                                    
logicalProject(logicalOlapScan())
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, 
countStar)))
                                 ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L)))))
                         ).when(FieldChecker.check("projects", 
Lists.newArrayList(a1.toSlot()))));
     }
@@ -298,17 +288,16 @@ public class FillUpMissingSlotsTest extends 
AnalyzeCheckTestBase implements Memo
                 .matches(
                         logicalProject(
                                 logicalFilter(
-                                        logicalProject(
-                                                logicalAggregate(
-                                                        logicalProject(
-                                                                logicalFilter(
-                                                                        
logicalJoin(
-                                                                               
 logicalOlapScan(),
-                                                                               
 logicalOlapScan()
-                                                                        )
-                                                                ))
-                                                
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, 
sumB1)))
-                                        
)).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new 
Cast(a1, BigIntType.INSTANCE),
+                                            logicalAggregate(
+                                                    logicalProject(
+                                                            logicalFilter(
+                                                                    
logicalJoin(
+                                                                            
logicalOlapScan(),
+                                                                            
logicalOlapScan()
+                                                                    )
+                                                            ))
+                                            
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, 
sumB1)))
+                                        ).when(FieldChecker.check("conjuncts", 
ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
                                         sumB1.toSlot()))))
                         ).when(FieldChecker.check("projects", 
Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
     }
diff --git a/regression-test/data/nereids_p0/aggregate/agg_window_project.out 
b/regression-test/data/nereids_p0/aggregate/agg_window_project.out
index dcdb0f25eca..45fde6faa92 100644
--- a/regression-test/data/nereids_p0/aggregate/agg_window_project.out
+++ b/regression-test/data/nereids_p0/aggregate/agg_window_project.out
@@ -12,3 +12,7 @@
 2      1       23.0000000000
 2      2       23.0000000000
 
+-- !select5 --
+1      1       3.0000000000
+1      2       3.0000000000
+
diff --git a/regression-test/suites/nereids_p0/aggregate/agg_error_msg.groovy 
b/regression-test/suites/nereids_p0/aggregate/agg_error_msg.groovy
new file mode 100644
index 00000000000..0b807be9d3a
--- /dev/null
+++ b/regression-test/suites/nereids_p0/aggregate/agg_error_msg.groovy
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("agg_error_msg") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+    sql "DROP TABLE IF EXISTS 
table_20_undef_partitions2_keys3_properties4_distributed_by58;"
+    sql """
+        create table 
table_20_undef_partitions2_keys3_properties4_distributed_by58 (
+        pk int,
+        col_int_undef_signed int   ,
+        col_int_undef_signed2 int   
+        ) engine=olap
+        DUPLICATE KEY(pk, col_int_undef_signed)
+        distributed by hash(pk) buckets 10
+        properties("replication_num" = "1");
+    """
+
+    sql "DROP TABLE IF EXISTS 
table_20_undef_partitions2_keys3_properties4_distributed_by53;"
+    sql """
+        create table 
table_20_undef_partitions2_keys3_properties4_distributed_by53 (
+        pk int,
+        col_int_undef_signed int   ,
+        col_int_undef_signed2 int   
+        ) engine=olap
+        DUPLICATE KEY(pk, col_int_undef_signed)
+        distributed by hash(pk) buckets 10
+        properties("replication_num" = "1");
+    """
+    test {
+        sql """SELECT col_int_undef_signed2   col_alias1, col_int_undef_signed 
 *  (SELECT  MAX (col_int_undef_signed) FROM 
table_20_undef_partitions2_keys3_properties4_distributed_by58 where 
table_20_undef_partitions2_keys3_properties4_distributed_by53.pk = pk)  AS 
col_alias2 FROM table_20_undef_partitions2_keys3_properties4_distributed_by53  
GROUP BY  GROUPING SETS ((col_int_undef_signed2),())  ;"""
+        exception "pk, col_int_undef_signed not in aggregate's output";
+    }
+}
diff --git 
a/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy 
b/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
index b75f46b1a06..1026201eca9 100644
--- a/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
+++ b/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
@@ -96,6 +96,8 @@ suite("agg_window_project") {
 
     order_qt_select4 """select a, c, sum(sum(b)) over(partition by c order by 
c rows between unbounded preceding and current row) from test_window_table2 
group by a, c having a > 1;"""
 
+    order_qt_select5 """select a, c, sum(sum(b)) over(partition by c order by 
c rows between unbounded preceding and current row) dd from test_window_table2 
group by a, c having dd < 4;"""
+    
     explain {
         sql("select a, c, sum(sum(b)) over(partition by c order by c rows 
between unbounded preceding and current row) from test_window_table2 group by 
a, c having a > 1;")
         contains "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to