This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push:
new a9027169577 [enhancement](nereids) only push having as agg's parent if
having just use slots from agg's output (#32847)
a9027169577 is described below
commit a9027169577854d18aa1476917bb1a7183a5e3e3
Author: starocean999 <[email protected]>
AuthorDate: Tue Mar 26 19:28:08 2024 +0800
[enhancement](nereids) only push having as agg's parent if having just use
slots from agg's output (#32847)
pick from master #32414
---
.../nereids/rules/analysis/NormalizeAggregate.java | 50 ++++++++--
.../nereids/rules/analysis/AnalyzeCTETest.java | 2 +-
.../rules/analysis/FillUpMissingSlotsTest.java | 101 +++++++++------------
.../nereids_p0/aggregate/agg_window_project.out | 4 +
.../nereids_p0/aggregate/agg_error_msg.groovy | 50 ++++++++++
.../nereids_p0/aggregate/agg_window_project.groovy | 2 +
6 files changed, 146 insertions(+), 63 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index 0176064a547..3f5f749e2fc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import
org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@@ -47,6 +48,7 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
+import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -252,13 +254,38 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
// create a parent project node
LogicalProject<Plan> project = new LogicalProject(upperProjects, newAggregate);
+ // verify project used slots are all coming from agg's output
+ List<Slot> slots = collectAllUsedSlots(upperProjects);
+ if (!slots.isEmpty()) {
+ Set<ExprId> aggOutputExprIds = new HashSet<>(slots.size());
+ for (NamedExpression expression : normalizedAggOutput) {
+ aggOutputExprIds.add(expression.getExprId());
+ }
+ List<Slot> errorSlots = new ArrayList<>(slots.size());
+ for (Slot slot : slots) {
+ if (!aggOutputExprIds.contains(slot.getExprId())) {
+ errorSlots.add(slot);
+ }
+ }
+ if (!errorSlots.isEmpty()) {
+ throw new AnalysisException(String.format("%s not in aggregate's output", errorSlots
+ .stream().map(NamedExpression::getName).collect(Collectors.joining(", "))));
+ }
+ }
if (having != null) {
- if (upperProjects.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) {
- // when project contains window functions, in order to get the correct result
- // push having through project to make it the parent node of logicalAgg
- return project.withChildren(ImmutableList.of(new LogicalHaving(
- ExpressionUtils.replace(having.getConjuncts(), project.getAliasToProducer()),
- project.child())));
+ Set<Slot> havingUsedSlots = ExpressionUtils.getInputSlotSet(having.getExpressions());
+ Set<ExprId> havingUsedExprIds = new HashSet<>(havingUsedSlots.size());
+ for (Slot slot : havingUsedSlots) {
+ havingUsedExprIds.add(slot.getExprId());
+ }
+ Set<ExprId> aggOutputExprIds = newAggregate.getOutputExprIdSet();
+ if (aggOutputExprIds.containsAll(havingUsedExprIds)) {
+ // when having just use output slots from agg, we push down having as parent of agg
+ return project.withChildren(ImmutableList.of(
+ new LogicalHaving<>(
+ ExpressionUtils.replace(having.getConjuncts(), project.getAliasToProducer()),
+ project.child()
+ )));
} else {
return (LogicalPlan) having.withChildren(project);
}
@@ -293,4 +320,15 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
}
return builder.build();
}
+
+ private List<Slot> collectAllUsedSlots(List<NamedExpression> expressions) {
+ Set<Slot> inputSlots = ExpressionUtils.getInputSlotSet(expressions);
+ List<SubqueryExpr> subqueries = ExpressionUtils.collectAll(expressions, SubqueryExpr.class::isInstance);
+ List<Slot> slots = new ArrayList<>(inputSlots.size() + subqueries.size());
+ for (SubqueryExpr subqueryExpr : subqueries) {
+ slots.addAll(subqueryExpr.getCorrelateSlots());
+ }
+ slots.addAll(ExpressionUtils.getInputSlotSet(expressions));
+ return slots;
+ }
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java
index ef5a32e2d3b..522f198e3ff 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java
@@ -140,7 +140,7 @@ public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatc
logicalFilter(
logicalProject(
logicalJoin(
- logicalProject(logicalAggregate()),
+ logicalAggregate(),
logicalProject()
)
)
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java
index d6c71ee7759..534833754bd 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java
@@ -87,10 +87,9 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalFilter(
- logicalProject(
- logicalAggregate(
-
logicalProject(logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))));
+ logicalAggregate(
+ logicalProject(logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))));
sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
SlotReference value = new SlotReference(new ExprId(3), "value",
TinyIntType.INSTANCE, true,
@@ -134,10 +133,9 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
-
logicalProject(logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
+ logicalAggregate(
+
logicalProject(logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))
).when(FieldChecker.check("projects",
Lists.newArrayList(sumA2.toSlot()))));
}
@@ -158,10 +156,9 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
-
logicalProject(logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
+ logicalAggregate(
+
logicalProject(logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects",
Lists.newArrayList(a1.toSlot()))));
@@ -171,13 +168,12 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
- logicalProject(
-
logicalOlapScan()
- )
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
- ).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
+ logicalAggregate(
+ logicalProject(
+ logicalOlapScan()
+ )
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
+ .when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
sql = "SELECT a1, sum(a2) as value FROM t1 GROUP BY a1 HAVING sum(a2)
> 0";
a1 = new SlotReference(
@@ -193,22 +189,20 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
- logicalProject(
-
logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
+ logicalAggregate(
+ logicalProject(
+ logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));
sql = "SELECT a1, sum(a2) as value FROM t1 GROUP BY a1 HAVING value >
0";
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalFilter(
- logicalProject(
- logicalAggregate(
- logicalProject(
- logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
+ logicalAggregate(
+ logicalProject(
+ logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))));
sql = "SELECT a1, sum(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
@@ -230,10 +224,9 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
-
logicalProject(logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2,
minPK))))
+ logicalAggregate(
+
logicalProject(logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2,
minPK)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))))
).when(FieldChecker.check("projects",
Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
@@ -243,10 +236,9 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
-
logicalProject(logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1,
sumA1A2))))
+ logicalAggregate(
+
logicalProject(logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))));
sql = "SELECT a1, sum(a1 + a2) FROM t1 GROUP BY a1 HAVING sum(a1 + a2
+ 3) > 0";
@@ -256,10 +248,9 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
-
logicalProject(logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2,
sumA1A23))))
+ logicalAggregate(
+
logicalProject(logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2,
sumA1A23)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects",
Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
@@ -269,10 +260,9 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
-
logicalProject(logicalOlapScan())
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1,
countStar))))
+ logicalAggregate(
+
logicalProject(logicalOlapScan())
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1,
countStar)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects",
Lists.newArrayList(a1.toSlot()))));
}
@@ -298,17 +288,16 @@ public class FillUpMissingSlotsTest extends
AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalFilter(
- logicalProject(
- logicalAggregate(
- logicalProject(
- logicalFilter(
-
logicalJoin(
-
logicalOlapScan(),
-
logicalOlapScan()
- )
- ))
-
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2,
sumB1)))
-
)).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new
Cast(a1, BigIntType.INSTANCE),
+ logicalAggregate(
+ logicalProject(
+ logicalFilter(
+
logicalJoin(
+
logicalOlapScan(),
+
logicalOlapScan()
+ )
+ ))
+
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2,
sumB1)))
+ ).when(FieldChecker.check("conjuncts",
ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
sumB1.toSlot()))))
).when(FieldChecker.check("projects",
Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
}
diff --git a/regression-test/data/nereids_p0/aggregate/agg_window_project.out
b/regression-test/data/nereids_p0/aggregate/agg_window_project.out
index dcdb0f25eca..45fde6faa92 100644
--- a/regression-test/data/nereids_p0/aggregate/agg_window_project.out
+++ b/regression-test/data/nereids_p0/aggregate/agg_window_project.out
@@ -12,3 +12,7 @@
2 1 23.0000000000
2 2 23.0000000000
+-- !select5 --
+1 1 3.0000000000
+1 2 3.0000000000
+
diff --git a/regression-test/suites/nereids_p0/aggregate/agg_error_msg.groovy
b/regression-test/suites/nereids_p0/aggregate/agg_error_msg.groovy
new file mode 100644
index 00000000000..0b807be9d3a
--- /dev/null
+++ b/regression-test/suites/nereids_p0/aggregate/agg_error_msg.groovy
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("agg_error_msg") {
+ sql "SET enable_nereids_planner=true"
+ sql "SET enable_fallback_to_original_planner=false"
+ sql "DROP TABLE IF EXISTS table_20_undef_partitions2_keys3_properties4_distributed_by58;"
+ sql """
+ create table table_20_undef_partitions2_keys3_properties4_distributed_by58 (
+ pk int,
+ col_int_undef_signed int ,
+ col_int_undef_signed2 int
+ ) engine=olap
+ DUPLICATE KEY(pk, col_int_undef_signed)
+ distributed by hash(pk) buckets 10
+ properties("replication_num" = "1");
+ """
+
+ sql "DROP TABLE IF EXISTS table_20_undef_partitions2_keys3_properties4_distributed_by53;"
+ sql """
+ create table table_20_undef_partitions2_keys3_properties4_distributed_by53 (
+ pk int,
+ col_int_undef_signed int ,
+ col_int_undef_signed2 int
+ ) engine=olap
+ DUPLICATE KEY(pk, col_int_undef_signed)
+ distributed by hash(pk) buckets 10
+ properties("replication_num" = "1");
+ """
+ test {
+ sql """SELECT col_int_undef_signed2 col_alias1, col_int_undef_signed * (SELECT MAX (col_int_undef_signed) FROM table_20_undef_partitions2_keys3_properties4_distributed_by58 where table_20_undef_partitions2_keys3_properties4_distributed_by53.pk = pk) AS col_alias2 FROM table_20_undef_partitions2_keys3_properties4_distributed_by53 GROUP BY GROUPING SETS ((col_int_undef_signed2),()) ;"""
+ exception "pk, col_int_undef_signed not in aggregate's output";
+ }
+}
diff --git
a/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
b/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
index b75f46b1a06..1026201eca9 100644
--- a/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
+++ b/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
@@ -96,6 +96,8 @@ suite("agg_window_project") {
order_qt_select4 """select a, c, sum(sum(b)) over(partition by c order by
c rows between unbounded preceding and current row) from test_window_table2
group by a, c having a > 1;"""
+ order_qt_select5 """select a, c, sum(sum(b)) over(partition by c order by
c rows between unbounded preceding and current row) dd from test_window_table2
group by a, c having dd < 4;"""
+
explain {
sql("select a, c, sum(sum(b)) over(partition by c order by c rows
between unbounded preceding and current row) from test_window_table2 group by
a, c having a > 1;")
contains "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]