This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push:
new ad688171f5c [fix](Nereids): handle distinct and nullable property when
rewriting sum literal (#32778)
ad688171f5c is described below
commit ad688171f5c87b6d98b936a8138dd57f0ca2704d
Author: 谢健 <[email protected]>
AuthorDate: Tue Mar 26 20:31:33 2024 +0800
[fix](Nereids): handle distinct and nullable property when rewriting sum
literal (#32778)
Preserve the distinct and always-nullable properties of the original sum
when rewriting a sum over "expression +/- literal" into a combination of
the sum and the count of that expression. For example:
select sum(v + 1), sum(distinct v + 1) from t
=>
select sum(v) + count(v), sum(distinct v) + count(distinct v) from t
---
.../nereids/rules/rewrite/SumLiteralRewrite.java | 81 ++++++++++++++++------
.../rules/rewrite/SumLiteralRewriteTest.java | 67 +++++++++++++++++-
.../data/nereids_rules_p0/sumRewrite.out | 47 +++----------
.../suites/nereids_rules_p0/sumRewrite.groovy | 37 +++++-----
4 files changed, 153 insertions(+), 79 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java
index 5ded4bc9a76..c99071a714e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java
@@ -44,6 +44,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Set;
/**
@@ -55,9 +56,9 @@ public class SumLiteralRewrite extends OneRewriteRuleFactory {
return logicalAggregate()
.whenNot(agg -> agg.getSourceRepeat().isPresent())
.then(agg -> {
- Map<NamedExpression, Pair<Expression, Literal>>
sumLiteralMap = new HashMap<>();
+ Map<NamedExpression, Pair<SumInfo, Literal>> sumLiteralMap
= new HashMap<>();
for (NamedExpression namedExpression : agg.getOutputs()) {
- Pair<NamedExpression, Pair<Expression, Literal>> pel =
extractSumLiteral(namedExpression);
+ Pair<NamedExpression, Pair<SumInfo, Literal>> pel =
extractSumLiteral(namedExpression);
if (pel == null) {
continue;
}
@@ -71,7 +72,7 @@ public class SumLiteralRewrite extends OneRewriteRuleFactory {
}
private Plan rewriteSumLiteral(
- LogicalAggregate<?> agg, Map<NamedExpression, Pair<Expression,
Literal>> sumLiteralMap) {
+ LogicalAggregate<?> agg, Map<NamedExpression, Pair<SumInfo,
Literal>> sumLiteralMap) {
Set<NamedExpression> newAggOutput = new HashSet<>();
for (NamedExpression expr : agg.getOutputExpressions()) {
if (!sumLiteralMap.containsKey(expr)) {
@@ -79,8 +80,8 @@ public class SumLiteralRewrite extends OneRewriteRuleFactory {
}
}
- Map<Expression, Slot> exprToSum = new HashMap<>();
- Map<Expression, Slot> exprToCount = new HashMap<>();
+ Map<SumInfo, Slot> exprToSum = new HashMap<>();
+ Map<SumInfo, Slot> exprToCount = new HashMap<>();
Map<AggregateFunction, NamedExpression> existedAggFunc = new
HashMap<>();
for (NamedExpression e : agg.getOutputExpressions()) {
@@ -89,16 +90,16 @@ public class SumLiteralRewrite extends
OneRewriteRuleFactory {
}
}
- Set<Expression> countSumExpr = new HashSet<>();
- for (Pair<Expression, Literal> pair : sumLiteralMap.values()) {
+ Set<SumInfo> countSumExpr = new HashSet<>();
+ for (Pair<SumInfo, Literal> pair : sumLiteralMap.values()) {
countSumExpr.add(pair.first);
}
- for (Expression e : countSumExpr) {
- NamedExpression namedSum = constructSum(e, existedAggFunc);
- NamedExpression namedCount = constructCount(e, existedAggFunc);
- exprToSum.put(e, namedSum.toSlot());
- exprToCount.put(e, namedCount.toSlot());
+ for (SumInfo info : countSumExpr) {
+ NamedExpression namedSum = constructSum(info, existedAggFunc);
+ NamedExpression namedCount = constructCount(info, existedAggFunc);
+ exprToSum.put(info, namedSum.toSlot());
+ exprToCount.put(info, namedCount.toSlot());
newAggOutput.add(namedSum);
newAggOutput.add(namedCount);
}
@@ -111,15 +112,15 @@ public class SumLiteralRewrite extends
OneRewriteRuleFactory {
}
private List<NamedExpression> constructProjectExpression(
- LogicalAggregate<?> agg, Map<NamedExpression, Pair<Expression,
Literal>> sumLiteralMap,
- Map<Expression, Slot> exprToSum, Map<Expression, Slot>
exprToCount) {
+ LogicalAggregate<?> agg, Map<NamedExpression, Pair<SumInfo,
Literal>> sumLiteralMap,
+ Map<SumInfo, Slot> exprToSum, Map<SumInfo, Slot> exprToCount) {
List<NamedExpression> newProjects = new ArrayList<>();
for (NamedExpression namedExpr : agg.getOutputExpressions()) {
if (!sumLiteralMap.containsKey(namedExpr)) {
newProjects.add(namedExpr.toSlot());
continue;
}
- Expression originExpr = sumLiteralMap.get(namedExpr).first;
+ SumInfo originExpr = sumLiteralMap.get(namedExpr).first;
Literal literal = sumLiteralMap.get(namedExpr).second;
Expression newExpr;
if (namedExpr.child(0).child(0) instanceof Add) {
@@ -134,8 +135,8 @@ public class SumLiteralRewrite extends
OneRewriteRuleFactory {
return newProjects;
}
- private NamedExpression constructSum(Expression child,
Map<AggregateFunction, NamedExpression> existedAggFunc) {
- Sum sum = new Sum(child);
+ private NamedExpression constructSum(SumInfo info, Map<AggregateFunction,
NamedExpression> existedAggFunc) {
+ Sum sum = new Sum(info.isDistinct, info.isAlwaysNullable, info.expr);
NamedExpression namedSum;
if (existedAggFunc.containsKey(sum)) {
namedSum = existedAggFunc.get(sum);
@@ -145,8 +146,8 @@ public class SumLiteralRewrite extends
OneRewriteRuleFactory {
return namedSum;
}
- private NamedExpression constructCount(Expression child,
Map<AggregateFunction, NamedExpression> existedAggFunc) {
- Count count = new Count(child);
+ private NamedExpression constructCount(SumInfo info,
Map<AggregateFunction, NamedExpression> existedAggFunc) {
+ Count count = new Count(info.isDistinct, info.expr);
NamedExpression namedCount;
if (existedAggFunc.containsKey(count)) {
namedCount = existedAggFunc.get(count);
@@ -156,7 +157,7 @@ public class SumLiteralRewrite extends
OneRewriteRuleFactory {
return namedCount;
}
- private @Nullable Pair<NamedExpression, Pair<Expression, Literal>>
extractSumLiteral(
+ private @Nullable Pair<NamedExpression, Pair<SumInfo, Literal>>
extractSumLiteral(
NamedExpression namedExpression) {
if (namedExpression.children().size() != 1) {
return null;
@@ -180,6 +181,44 @@ public class SumLiteralRewrite extends
OneRewriteRuleFactory {
// only support integer or float types
return null;
}
- return Pair.of(namedExpression, Pair.of(left, (Literal) right));
+ SumInfo info = new SumInfo(left, ((Sum) func).isDistinct(), ((Sum)
func).isAlwaysNullable());
+ return Pair.of(namedExpression, Pair.of(info, (Literal) right));
+ }
+
+ static class SumInfo {
+ Expression expr;
+ boolean isDistinct;
+ boolean isAlwaysNullable;
+
+ SumInfo(Expression expr, boolean isDistinct, boolean isAlwaysNullable)
{
+ this.expr = expr;
+ this.isDistinct = isDistinct;
+ this.isAlwaysNullable = isAlwaysNullable;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ SumInfo sumInfo = (SumInfo) o;
+
+ if (isDistinct != sumInfo.isDistinct) {
+ return false;
+ }
+ if (isAlwaysNullable != sumInfo.isAlwaysNullable) {
+ return false;
+ }
+ return Objects.equals(expr, sumInfo.expr);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(expr, isDistinct, isAlwaysNullable);
+ }
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
index 97dda3f8cb9..cb2cc77627e 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
@@ -37,18 +37,79 @@ class SumLiteralRewriteTest implements
MemoPatternMatchSupported {
private final LogicalOlapScan scan1 =
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
@Test
- void testSimpleAddSum() {
+ void testSimpleSum() {
Slot slot1 = scan1.getOutput().get(0);
- Alias sum = new Alias(new Sum(slot1));
Alias add1 = new Alias(new Sum(new Add(slot1, Literal.of(1))));
Alias add2 = new Alias(new Sum(new Add(slot1, Literal.of(2))));
Alias sub1 = new Alias(new Sum(new Subtract(slot1, Literal.of(1))));
Alias sub2 = new Alias(new Sum(new Subtract(slot1, Literal.of(2))));
LogicalAggregate<?> agg = new LogicalAggregate<>(
+ ImmutableList.of(scan1.getOutput().get(0)),
ImmutableList.of(add1, add2, sub1, sub2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ .matches(logicalAggregate().when(p -> p.getOutputs().size() ==
2));
+
+ Alias sum = new Alias(new Sum(slot1));
+ agg = new LogicalAggregate<>(
ImmutableList.of(scan1.getOutput().get(0)),
ImmutableList.of(sum, add1, add2, sub1, sub2), scan1);
PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
.applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
.printlnTree()
- .matches(logicalAggregate().when(p ->
p.getAggregateFunctions().size() == 2));
+ .matches(logicalAggregate().when(p -> p.getOutputs().size() ==
2));
+ }
+
+ @Test
+ void testSumNullable() {
+ Slot slot1 = scan1.getOutput().get(0);
+ Alias add1 = new Alias(new Sum(false, true, new Add(slot1,
Literal.of(1))));
+ Alias add2 = new Alias(new Sum(false, true, new Add(slot1,
Literal.of(2))));
+ Alias sub1 = new Alias(new Sum(false, true, new Subtract(slot1,
Literal.of(1))));
+ Alias sub2 = new Alias(new Sum(false, true, new Subtract(slot1,
Literal.of(2))));
+ LogicalAggregate<?> agg = new LogicalAggregate<>(
+ ImmutableList.of(scan1.getOutput().get(0)),
ImmutableList.of(add1, add2, sub1, sub2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ .matches(logicalAggregate().when(p -> p.getOutputs().size() ==
2));
+
+ Alias sum = new Alias(new Sum(false, true, slot1));
+ agg = new LogicalAggregate<>(
+ ImmutableList.of(scan1.getOutput().get(0)),
ImmutableList.of(sum, add1, add2, sub1, sub2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ .matches(logicalAggregate().when(p -> p.getOutputs().size() ==
2));
+ }
+
+ @Test
+ void testSumDistinct() {
+ Slot slot1 = scan1.getOutput().get(0);
+ Alias add1 = new Alias(new Sum(true, true, new Add(slot1,
Literal.of(1))));
+ Alias add2 = new Alias(new Sum(false, true, new Add(slot1,
Literal.of(2))));
+ Alias sub1 = new Alias(new Sum(true, true, new Subtract(slot1,
Literal.of(1))));
+ Alias sub2 = new Alias(new Sum(false, true, new Subtract(slot1,
Literal.of(2))));
+ LogicalAggregate<?> agg = new LogicalAggregate<>(
+ ImmutableList.of(scan1.getOutput().get(0)),
ImmutableList.of(add1, add2, sub1, sub2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ .matches(logicalAggregate().when(p -> p.getOutputs().size() ==
4));
+
+ Alias sumDistinct = new Alias(new Sum(true, true, slot1));
+ agg = new LogicalAggregate<>(
+ ImmutableList.of(scan1.getOutput().get(0)),
ImmutableList.of(sumDistinct, add1, add2, sub1, sub2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ .matches(logicalAggregate().when(p -> p.getOutputs().size() ==
4));
+
+ Alias sum = new Alias(new Sum(false, true, slot1));
+ agg = new LogicalAggregate<>(
+ ImmutableList.of(scan1.getOutput().get(0)),
ImmutableList.of(sumDistinct, sum, add1, add2, sub1, sub2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ .matches(logicalAggregate().when(p -> p.getOutputs().size() ==
4));
}
}
diff --git a/regression-test/data/nereids_rules_p0/sumRewrite.out
b/regression-test/data/nereids_rules_p0/sumRewrite.out
index ddb4b90175d..65c174bae07 100644
--- a/regression-test/data/nereids_rules_p0/sumRewrite.out
+++ b/regression-test/data/nereids_rules_p0/sumRewrite.out
@@ -100,43 +100,12 @@
8 8.099999904632568
9 12.5
--- !decimal_sum_add_const_precision_1$ --
-2670.55
-
--- !decimal_sum_add_const_precision_2$ --
-2672.55
-
--- !decimal_sum_add_const_precision_3$ --
-10 694.19
-6 434.03
-7 474.07
-8 514.11
-9 554.15
-
--- !decimal_sum_add_const_precision_4$ --
-10 694.636
-6 434.476
-7 474.516
-8 514.556
-9 554.596
-
--- !decimal_sum_sub_const_precision_1$ --
-2630.55
-
--- !decimal_sum_sub_const_precision_2$ --
-2628.55
-
--- !decimal_sum_sub_const_precision_3$ --
-10 686.19
-6 426.03
-7 466.07
-8 506.11
-9 546.15
-
--- !decimal_sum_sub_const_precision_4$ --
-10 685.744
-6 425.584
-7 465.624
-8 505.664
-9 545.704
+-- !sum_null_and_not_null --
+122 130 114 80 90 70
+
+-- !sum_distinct --
+122 122 130 130
+
+-- !sum_not_null_distinct --
+80 40 90 45
diff --git a/regression-test/suites/nereids_rules_p0/sumRewrite.groovy
b/regression-test/suites/nereids_rules_p0/sumRewrite.groovy
index 6a6e8a02c87..e6d03c52f23 100644
--- a/regression-test/suites/nereids_rules_p0/sumRewrite.groovy
+++ b/regression-test/suites/nereids_rules_p0/sumRewrite.groovy
@@ -25,7 +25,7 @@ suite("sumRewrite") {
sql """
CREATE TABLE IF NOT EXISTS sr(
`id` int NULL,
- `null_id` int not NULL,
+ `not_null_id` int not NULL,
`f_id` float NULL,
`d_id` decimal(10,2),
) ENGINE = OLAP
@@ -36,7 +36,7 @@ suite("sumRewrite") {
"""
sql """
-INSERT INTO sr (id, null_id, f_id, d_id) VALUES
+INSERT INTO sr (id, not_null_id, f_id, d_id) VALUES
(11, 6, 1.1, 210.01),
(12, 6, 2.2, 220.02),
(13, 7, 3.3, 230.03),
@@ -55,9 +55,9 @@ INSERT INTO sr (id, null_id, f_id, d_id) VALUES
order_qt_sum_add_const_where$ """ select sum(id + 2) from sr where id is
not null """
- order_qt_sum_add_const_group_by$ """ select null_id, sum(id + 2) from sr
group by null_id """
+ order_qt_sum_add_const_group_by$ """ select not_null_id, sum(id + 2) from
sr group by not_null_id """
- order_qt_sum_add_const_having$ """ select null_id, sum(id + 2) from sr
group by null_id having sum(id + 2) > 5 """
+ order_qt_sum_add_const_having$ """ select not_null_id, sum(id + 2) from sr
group by not_null_id having sum(id + 2) > 5 """
order_qt_sum_sub_const$ """ select sum(id - 2) from sr """
@@ -65,17 +65,17 @@ INSERT INTO sr (id, null_id, f_id, d_id) VALUES
order_qt_sum_sub_const_where$ """ select sum(id - 2) from sr where id is
not null """
- order_qt_sum_sub_const_group_by$ """ select null_id, sum(id - 2) from sr
group by null_id """
+ order_qt_sum_sub_const_group_by$ """ select not_null_id, sum(id - 2) from
sr group by not_null_id """
- order_qt_sum_sub_const_having$ """ select null_id, sum(id - 2) from sr
group by null_id having sum(id - 2) > 0 """
+ order_qt_sum_sub_const_having$ """ select not_null_id, sum(id - 2) from sr
group by not_null_id having sum(id - 2) > 0 """
order_qt_sum_add_const_empty_table$ """ select sum(id + 2) from sr where
1=0 """
- order_qt_sum_add_const_empty_table_group_by$ """ select null_id, sum(id +
2) from sr where 1=0 group by null_id """
+ order_qt_sum_add_const_empty_table_group_by$ """ select not_null_id,
sum(id + 2) from sr where 1=0 group by not_null_id """
order_qt_sum_sub_const_empty_table$ """ select sum(id - 2) from sr where
1=0 """
- order_qt_sum_sub_const_empty_table_group_by$ """ select null_id, sum(id -
2) from sr where 1=0 group by null_id """
+ order_qt_sum_sub_const_empty_table_group_by$ """ select not_null_id,
sum(id - 2) from sr where 1=0 group by not_null_id """
// float类型字段测试
order_qt_float_sum_add_const$ """ select sum(f_id + 2) from sr """
@@ -84,9 +84,9 @@ INSERT INTO sr (id, null_id, f_id, d_id) VALUES
order_qt_float_sum_add_const_where$ """ select sum(f_id + 2) from sr where
f_id is not null """
- order_qt_float_sum_add_const_group_by$ """ select null_id, sum(f_id + 2)
from sr group by null_id """
+ order_qt_float_sum_add_const_group_by$ """ select not_null_id, sum(f_id +
2) from sr group by not_null_id """
- order_qt_float_sum_add_const_having$ """ select null_id, sum(f_id + 2)
from sr group by null_id having sum(f_id + 2) > 5 """
+ order_qt_float_sum_add_const_having$ """ select not_null_id, sum(f_id + 2)
from sr group by not_null_id having sum(f_id + 2) > 5 """
order_qt_float_sum_sub_const$ """ select sum(f_id - 2) from sr """
@@ -94,25 +94,30 @@ INSERT INTO sr (id, null_id, f_id, d_id) VALUES
order_qt_float_sum_sub_const_where$ """ select sum(f_id - 2) from sr where
f_id is not null """
- order_qt_float_sum_sub_const_group_by$ """ select null_id, sum(f_id - 2)
from sr group by null_id """
+ order_qt_float_sum_sub_const_group_by$ """ select not_null_id, sum(f_id -
2) from sr group by not_null_id """
- order_qt_float_sum_sub_const_having$ """ select null_id, sum(f_id - 2)
from sr group by null_id having sum(f_id - 2) > 0 """
+ order_qt_float_sum_sub_const_having$ """ select not_null_id, sum(f_id - 2)
from sr group by not_null_id having sum(f_id - 2) > 0 """
+ order_qt_sum_null_and_not_null """ select sum(id), sum(id + 1), sum(id -
1), sum(not_null_id), sum(not_null_id + 1), sum(not_null_id - 1) from sr"""
+
+ order_qt_sum_distinct """select sum(id), sum(distinct id), sum(id + 1),
sum(distinct id + 1) from sr"""
+
+ order_qt_sum_not_null_distinct """select sum(not_null_id), sum(distinct
not_null_id), sum(not_null_id + 1), sum(distinct not_null_id + 1) from sr"""
// 测试精度变化对sum加常数的影响
// order_qt_decimal_sum_add_const_precision_1$ """ select sum(d_id + 2)
from sr """
// order_qt_decimal_sum_add_const_precision_2$ """ select sum(d_id + 2.2)
from sr """
- // order_qt_decimal_sum_add_const_precision_3$ """ select null_id,
sum(d_id + 2) from sr group by null_id """
+ // order_qt_decimal_sum_add_const_precision_3$ """ select not_null_id,
sum(d_id + 2) from sr group by not_null_id """
- // order_qt_decimal_sum_add_const_precision_4$ """ select null_id,
sum(d_id + 2.223) from sr group by null_id """
+ // order_qt_decimal_sum_add_const_precision_4$ """ select not_null_id,
sum(d_id + 2.223) from sr group by not_null_id """
// 测试精度变化对sum减常数的影响
// order_qt_decimal_sum_sub_const_precision_1$ """ select sum(d_id - 2)
from sr """
// order_qt_decimal_sum_sub_const_precision_2$ """ select sum(d_id - 2.2)
from sr """
- // order_qt_decimal_sum_sub_const_precision_3$ """ select null_id,
sum(d_id - 2) from sr group by null_id """
+ // order_qt_decimal_sum_sub_const_precision_3$ """ select not_null_id,
sum(d_id - 2) from sr group by not_null_id """
- // order_qt_decimal_sum_sub_const_precision_4$ """ select null_id,
sum(d_id - 2.223) from sr group by null_id """
+ // order_qt_decimal_sum_sub_const_precision_4$ """ select not_null_id,
sum(d_id - 2.223) from sr group by not_null_id """
}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]