This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new ebf474d9d89 [feature](nereids) deal the slots that appear both in agg
func and grouping sets (#31318)
ebf474d9d89 is described below
commit ebf474d9d89cbca6728076ec1a27afbc0e51908f
Author: feiniaofeiafei <[email protected]>
AuthorDate: Mon Feb 26 19:59:51 2024 +0800
[feature](nereids) deal the slots that appear both in agg func and grouping
sets (#31318)
This PR supports slots that appear both in an aggregate function and in the grouping sets,
for SQL like the following:
select sum(a) from t group by grouping sets ((a));
Before this PR, Nereids threw an exception like the following:
col_int_undef_signed cannot both in select list and aggregate functions
when using GROUPING SETS/CUBE/ROLLUP, please use union instead.
This PR removes the restriction and supports this situation.
---
.../nereids/rules/analysis/NormalizeRepeat.java | 100 ++++++++++++++++-----
.../grouping_sets/test_grouping_sets.out | 26 ++++++
...ot_both_appear_in_agg_fun_and_grouping_sets.out | 66 ++++++++++++++
.../query_p0/grouping_sets/test_grouping_sets.out | 5 ++
.../grouping_sets/test_grouping_sets.groovy | 26 ++----
...both_appear_in_agg_fun_and_grouping_sets.groovy | 62 +++++++++++++
.../suites/nereids_syntax_p0/grouping_sets.groovy | 16 ----
.../grouping_sets/test_grouping_sets.groovy | 27 +-----
8 files changed, 248 insertions(+), 80 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
index 005cc663862..9326ee725ff 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
@@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.RuleType;
import
org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import
org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotTriplet;
import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
@@ -44,8 +43,10 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
+import org.jetbrains.annotations.NotNull;
import java.util.Collection;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -80,35 +81,16 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory
{
logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> {
checkRepeatLegality(repeat);
// add virtual slot, LogicalAggregate and LogicalProject for
normalize
- return normalizeRepeat(repeat);
+ LogicalAggregate<Plan> agg = normalizeRepeat(repeat);
+ return dealSlotAppearBothInAggFuncAndGroupingSets(agg);
})
);
}
private void checkRepeatLegality(LogicalRepeat<Plan> repeat) {
- checkIfAggFuncSlotInGroupingSets(repeat);
checkGroupingSetsSize(repeat);
}
- private void checkIfAggFuncSlotInGroupingSets(LogicalRepeat<Plan> repeat) {
- Set<Slot> aggUsedSlots = repeat.getOutputExpressions().stream()
- .flatMap(e ->
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
- .flatMap(e ->
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
- .collect(ImmutableSet.toImmutableSet());
- Set<ExprId> groupingSetsUsedSlotExprIds =
repeat.getGroupingSets().stream()
- .flatMap(Collection::stream)
- .flatMap(e ->
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
- .map(SlotReference::getExprId)
- .collect(Collectors.toSet());
- for (Slot slot : aggUsedSlots) {
- if (groupingSetsUsedSlotExprIds.contains(slot.getExprId())) {
- throw new AnalysisException("column: " + slot.toSql() + "
cannot both in select "
- + "list and aggregate functions when using GROUPING
SETS/CUBE/ROLLUP, "
- + "please use union instead.");
- }
- }
- }
-
private void checkGroupingSetsSize(LogicalRepeat<Plan> repeat) {
Set<Expression> flattenGroupingSetExpr = ImmutableSet.copyOf(
ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
@@ -265,4 +247,78 @@ public class NormalizeRepeat extends
OneAnalysisRuleFactory {
return expr;
}
}
+
+ /*
+ * Handles slots that are referenced both inside an aggregate function and in the
+ * grouping sets of the repeat directly below {@code aggregate}.
+ *
+ * For every such slot:
+ *  1. an Alias of the slot is added to a new LogicalProject that is inserted between
+ *     the repeat and its former child (the project outputs the child's slots plus
+ *     the aliases);
+ *  2. the alias slot (the "copy") is added to the repeat's output, placed after the
+ *     original non-virtual output slots and before the virtual slots;
+ *  3. slots inside AggregateFunction nodes of the aggregate's outputs are rewritten
+ *     to reference the copy. The group-by expressions are not modified here.
+ * NOTE(review): presumably the copy is needed because the repeat nulls out
+ * grouping-set columns that are absent from the current grouping set, while the
+ * aggregate function must see the original value — confirm.
+ *
+ * Example, plan produced by normalizeRepeat:
+ * LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, GROUPING_ID#1, sum(a#0) as `sum(a)`#2])
+ * +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, GROUPING_ID#1])
+ *    +--LogicalProject (projects=[a#0])
+ * After this method:
+ * LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, GROUPING_ID#1, sum(a#3) as `sum(a)`#2])
+ * +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, a#3, GROUPING_ID#1])
+ *    +--LogicalProject (projects=[a#0, a#0 as `a`#3])   <- inserted by this method
+ *       +--LogicalProject (projects=[a#0])
+ *
+ * @param aggregate the aggregate built by normalizeRepeat; its child is cast to LogicalRepeat
+ * @return {@code aggregate} itself when no slot is shared, otherwise the rewritten aggregate
+ */
+ private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
+ @NotNull LogicalAggregate<Plan> aggregate) {
+ LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child(); // assumes the repeat is the direct child; the cast fails otherwise
+ Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream() // slots referenced inside aggregate functions of the outputs
+ .flatMap(e ->
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
+ .flatMap(e ->
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
+ .collect(ImmutableSet.toImmutableSet());
+ Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream() // slots referenced anywhere in the grouping-set expressions
+ .flatMap(Collection::stream)
+ .flatMap(e ->
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
+ .collect(Collectors.toSet());
+ // intersection of the two sets: the slots that need a copy
+ Set<Slot> resSet = new HashSet<>(aggUsedSlots);
+ resSet.retainAll(groupingSetsUsedSlot);
+ if (resSet.isEmpty()) {
+ return aggregate;
+ }
+ Map<Slot, Alias> slotMapping = resSet.stream().collect( // each shared slot -> a new Alias of that slot
+ Collectors.toMap(key -> key, Alias::new)
+ );
+ Set<Alias> newAliases = new HashSet<>(slotMapping.values());
+ List<Slot> newSlots = newAliases.stream()
+ .map(Alias::toSlot)
+ .collect(Collectors.toList());
+ // NOTE(review): HashSet iteration order is unspecified, so the order of the copies in the new project and repeat outputs may differ between runs; a LinkedHashSet would make plans stable
+ // modify repeat child to a new project with more projections
+ List<Slot> originSlots = repeat.child().getOutput();
+ ImmutableList<NamedExpression> immList =
+
ImmutableList.<NamedExpression>builder().addAll(originSlots).addAll(newAliases).build();
+ LogicalProject<Plan> newProject = new LogicalProject<>(immList,
repeat.child());
+ repeat = repeat.withChildren(ImmutableList.of(newProject));
+
+ // modify repeat outputs: original non-virtual slots, then the copies, then the virtual slots
+ List<Slot> originRepeatSlots = repeat.getOutput();
+ repeat = repeat.withAggOutput(ImmutableList
+ .<NamedExpression>builder()
+ .addAll(originRepeatSlots.stream().filter(slot -> ! (slot
instanceof VirtualSlotReference))
+ .collect(Collectors.toList()))
+ .addAll(newSlots)
+ .addAll(originRepeatSlots.stream().filter(slot -> (slot
instanceof VirtualSlotReference))
+ .collect(Collectors.toList()))
+ .build());
+ aggregate = aggregate.withChildren(ImmutableList.of(repeat)); // re-attach the rewritten repeat
+
+ // modify aggregate functions' parameter slot reference to new copied
slots
+ List<NamedExpression> newOutputExpressions =
aggregate.getOutputExpressions().stream()
+ .map(output -> (NamedExpression)
output.rewriteDownShortCircuit(expr -> {
+ if (expr instanceof AggregateFunction) { // only slots inside aggregate functions are redirected to the copies
+ return expr.rewriteDownShortCircuit(e -> {
+ if (e instanceof Slot &&
slotMapping.containsKey(e)) {
+ return slotMapping.get(e).toSlot();
+ }
+ return e;
+ });
+ }
+ return expr;
+ })
+ ).collect(Collectors.toList());
+ return aggregate.withAggOutput(newOutputExpressions);
+ }
}
diff --git
a/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out
b/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out
index f2da1d2f673..67d76e45936 100644
--- a/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out
+++ b/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out
@@ -48,4 +48,30 @@
2 10 1991
-- !select7 --
+\N \N 1002
+\N \N 2002
+\N \N 3004
+\N 1986 1001
+\N 1989 2003
+1 \N 1001
+1 1989 1001
+2 \N 1001
+2 1986 1001
+3 \N 1002
+3 1989 1002
+
+-- !select8 --
+\N \N 0.9990029910269193
+\N \N 0.9995007488766849
+\N \N 0.9996672212978369
+\N 1986 0.999001996007984
+\N 1989 0.9995009980039921
+1 \N 0.999001996007984
+1 1989 0.999001996007984
+2 \N 0.999001996007984
+2 1986 0.999001996007984
+3 \N 0.9990029910269193
+3 1989 0.9990029910269193
+
+-- !select9 --
diff --git
a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
new file mode 100644
index 00000000000..901226f8548
--- /dev/null
+++
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
@@ -0,0 +1,66 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !select1 --
+\N
+\N
+-48
+-48
+-43
+-43
+-43
+-12
+82
+82
+89
+89
+
+-- !select2 --
+\N
+\N
+-46
+-46
+-39
+-39
+-38
+-11
+91
+91
+97
+97
+
+-- !select3 --
+\N
+\N
+\N
+-47
+-47
+-47
+-42
+-42
+-42
+-42
+-11
+83
+83
+90
+90
+16055
+19197
+
+-- !select4 --
+\N
+a
+how
+j
+say
+yeah
+
+-- !select5 --
+1
+1
+1
+2
+3
+3
+4
+5
+
diff --git a/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out
b/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out
index b3d3050ee77..052d4e1c35d 100644
--- a/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out
+++ b/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out
@@ -203,3 +203,8 @@ test 2
1989-03-21 \N 1001 0 1 1
2012-03-14 \N 1002 0 1 1
+-- !select24 --
+1 0
+2 0
+3 0
+
diff --git
a/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy
b/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy
index b5671a77a56..79a193c95e2 100644
--- a/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy
+++ b/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy
@@ -45,27 +45,15 @@ suite("test_grouping_sets") {
group by grouping sets((k_if, k1),()) order by k_if, k1,
k2_sum
"""
- test {
- sql """
- SELECT k1, k2, SUM(k3) FROM nereids_test_query_db.test
- GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order
by k1, k2
+ qt_select7 """
+ SELECT k1, k2, SUM(k3) k3_ FROM nereids_test_query_db.test
+ GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order
by k1, k2, k3_
"""
- check{result, exception, startTime, endTime ->
- assertTrue(exception != null)
- logger.info(exception.message)
- }
- }
- test {
- sql """
- SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM
nereids_test_query_db.test
- GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order
by k1, k2
+ qt_select8 """
+ SELECT k1, k2, SUM(k3)/(SUM(k3)+1) k3_ FROM
nereids_test_query_db.test
+ GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order
by k1, k2, k3_
"""
- check{result, exception, startTime, endTime ->
- assertTrue(exception != null)
- logger.info(exception.message)
- }
- }
- qt_select7 """ select k1,k2,sum(k3) from nereids_test_query_db.test where 1
= 2 group by grouping sets((k1), (k1,k2)) """
+ qt_select9 """ select k1,k2,sum(k3) from nereids_test_query_db.test where
1 = 2 group by grouping sets((k1), (k1,k2)) """
}
diff --git
a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
new file mode 100644
index 00000000000..ac711cf5aab
--- /dev/null
+++
b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+suite("slot_both_appear_in_agg_fun_and_grouping_sets") {
+ // Regression cases for #31318: a column read inside an aggregate function is also used in the grouping sets.
+ sql """
+ DROP TABLE IF EXISTS table_10_undef_undef4
+ """
+ // fixture: 10 rows (pk 0-9); both value columns contain NULLs
+ sql """
+ create table table_10_undef_undef4 (`pk` int,`col_int_undef_signed`
int ,
+ `col_text_undef_signed` text ) engine=olap distributed by hash(pk)
buckets 10
+ properties( 'replication_num' = '1');
+ """
+
+ sql """
+ insert into table_10_undef_undef4 values (0,16054,null),(1,-12,null),
+
(2,-48,'j'),(3,null,null),(4,-43,"say"),(5,-43,null),(6,null,'a'),(7,19196,null),
+ (8,89,"how"),(9,82,"yeah");
+
+ """
+ // MIN() reads col_int_undef_signed, which is also a grouping-set column; HAVING filters on grouping-set columns
+ qt_select1 """
+ SELECT MIN(`col_int_undef_signed`) FROM table_10_undef_undef4 AS T1
GROUP BY
+ GROUPING SETS((`col_int_undef_signed`,`col_text_undef_signed`),
(`col_text_undef_signed`), ())
+ HAVING T1.`col_int_undef_signed` < 3 OR T1.col_text_undef_signed >
'' order by 1;
+ """
+ // the aggregate argument combines two grouping-set columns (col_int_undef_signed and pk)
+ qt_select2 """
+ SELECT MIN(col_int_undef_signed+pk) FROM table_10_undef_undef4 AS T1
GROUP BY
+ GROUPING SETS((col_int_undef_signed,col_text_undef_signed),
+ (col_text_undef_signed), (pk),()) HAVING T1.col_int_undef_signed < 3
OR T1.col_text_undef_signed > '' order by 1;
+ """
+ // the grouping set uses the expression col_int_undef_signed+1, the same expression the aggregate reads
+ qt_select3 """
+ SELECT MIN(col_int_undef_signed+1) FROM table_10_undef_undef4 AS T1
GROUP BY
+ GROUPING SETS((col_int_undef_signed+1,col_text_undef_signed),
(col_text_undef_signed), ()) order by 1;
+ """
+ // string aggregate over the only grouping-set column
+ qt_select4 """
+ select group_concat(col_text_undef_signed,',' ) from
table_10_undef_undef4
+ group by grouping sets((col_text_undef_signed)) order by 1;
+ """
+ // window function nested in an aggregate; its ORDER BY column is the grouping-set column
+ qt_select5 """
+ select sum(rank() over (partition by col_text_undef_signed order by
col_int_undef_signed))
+ as col1 from table_10_undef_undef4 group by grouping
sets((col_int_undef_signed)) order by 1;
+ """
+}
diff --git a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy
b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy
index 0845d705e86..8ca787fabfb 100644
--- a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy
+++ b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy
@@ -138,22 +138,6 @@ suite("test_nereids_grouping_sets") {
group by grouping sets((k_if, k1),()) order by k_if, k1,
k2_sum
"""
- test {
- sql """
- SELECT k1, k2, SUM(k3) FROM groupingSetsTable
- GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order
by k1, k2
- """
- exception "java.sql.SQLException: errCode = 2, detailMessage = column:
k3 cannot both in select list and aggregate functions when using GROUPING
SETS/CUBE/ROLLUP, please use union instead."
- }
-
- test {
- sql """
- SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM groupingSetsTable
- GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order
by k1, k2
- """
- exception "java.sql.SQLException: errCode = 2, detailMessage = column:
k3 cannot both in select list and aggregate functions when using GROUPING
SETS/CUBE/ROLLUP, please use union instead."
- }
-
order_qt_select """
select k1, sum(k2) from (select k1, k2, grouping(k1), grouping(k2)
from groupingSetsTableNotNullable group by grouping sets((k1), (k2)))a group by
k1
"""
diff --git
a/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy
b/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy
index 6564bca3509..c56ba366bbb 100644
--- a/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy
+++ b/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy
@@ -52,15 +52,6 @@ suite("test_grouping_sets", "p0") {
exception "errCode = 2, detailMessage = column: `k3` cannot both in
select list and aggregate functions"
}
- sql """set enable_nereids_planner=true;"""
- sql """set enable_fallback_to_original_planner=false;"""
- test {
- sql """
- SELECT k1, k2, SUM(k3) FROM test_query_db.test
- GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order
by k1, k2
- """
- exception "errCode = 2, detailMessage = column: k3 cannot both in
select list and aggregate functions"
- }
sql """set enable_nereids_planner=false;"""
sql """set enable_fallback_to_original_planner=true;"""
test {
@@ -71,15 +62,6 @@ suite("test_grouping_sets", "p0") {
exception "errCode = 2, detailMessage = column: `k3` cannot both in
select list and aggregate functions"
}
- sql """set enable_nereids_planner=true;"""
- sql """set enable_fallback_to_original_planner=false;"""
- test {
- sql """
- SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM test_query_db.test
- GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order
by k1, k2
- """
- exception "errCode = 2, detailMessage = column: k3 cannot both in
select list and aggregate functions"
- }
sql """set enable_nereids_planner=false;"""
sql """set enable_fallback_to_original_planner=true;"""
@@ -269,9 +251,8 @@ suite("test_grouping_sets", "p0") {
sql """set enable_nereids_planner=true;"""
sql """set enable_fallback_to_original_planner=false;"""
- test {
- sql "select k1, if(grouping(k1)=1, count(k1), 0) from
test_query_db.test group by grouping sets((k1))"
- exception "k1 cannot both in select list and aggregate functions " +
- "when using GROUPING SETS/CUBE/ROLLUP, please use union
instead."
- }
+ qt_select24 """
+ select k1, if(grouping(k1)=1, count(k1), 0) from test_query_db.test
group by grouping sets((k1))
+ order by 1,2
+ """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]