This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 283c1c111d [CALCITE-6333] NullPointerException in
AggregateExpandDistinctAggregatesRule.doRewrite when rewriting filtered
distinct aggregation
283c1c111d is described below
commit 283c1c111d2d2682c728f681bebc5307d40d70a5
Author: abhishekagarwal87 <[email protected]>
AuthorDate: Fri Mar 15 15:49:34 2024 +0530
[CALCITE-6333] NullPointerException in
AggregateExpandDistinctAggregatesRule.doRewrite when rewriting filtered
distinct aggregation
Fix test order
Fix one more
---
.../AggregateExpandDistinctAggregatesRule.java | 9 ++--
.../org/apache/calcite/test/RelOptRulesTest.java | 21 +++++++++
.../org/apache/calcite/test/RelOptRulesTest.xml | 54 ++++++++++++++++++++++
3 files changed, 79 insertions(+), 5 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
index ccf32284e7..56e224572c 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
@@ -741,6 +741,9 @@ public final class AggregateExpandDistinctAggregatesRule
if (!aggCall.getArgList().equals(argList)) {
continue;
}
+ if (aggCall.filterArg != filterArg) {
+ continue;
+ }
// Re-map arguments.
final int argCount = aggCall.getArgList().size();
@@ -748,14 +751,10 @@ public final class AggregateExpandDistinctAggregatesRule
for (Integer arg : aggCall.getArgList()) {
newArgs.add(requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" +
arg + ")"));
}
- final int newFilterArg =
- aggCall.filterArg < 0 ? -1
- : requireNonNull(sourceOf.get(aggCall.filterArg),
- () -> "sourceOf.get(" + aggCall.filterArg + ")");
final AggregateCall newAggCall =
AggregateCall.create(aggCall.getAggregation(), false,
aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList,
- newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation,
+ newArgs, -1, aggCall.distinctKeys, aggCall.collation,
aggCall.getType(), aggCall.getName());
assert refs.get(i) == null;
if (leftFields == null) {
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 0681d4b7fc..d02cf05c1d 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -2175,6 +2175,27 @@ class RelOptRulesTest extends RelOptTestBase {
.check();
}
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-6333">[CALCITE-6333]
+ * NullPointerException in AggregateExpandDistinctAggregatesRule.doRewrite
+ * when rewriting filtered distinct aggregation</a>.
+ *
+ * <p>Applies the join-based expansion rule to a query that mixes a plain
+ * aggregate with a single filtered DISTINCT aggregate and no GROUP BY. */
+ @Test void testDistinctWithFilterWithoutGroupByUsingJoin() {
+ final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n"
+ + "FROM emp";
+ sql(sql)
+ .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN)
+ .check();
+ }
+ }
+
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-6333">[CALCITE-6333]
+ * NullPointerException in AggregateExpandDistinctAggregatesRule.doRewrite
+ * when rewriting filtered distinct aggregation</a>.
+ *
+ * <p>Two DISTINCT aggregates share the argument {@code sal} but carry
+ * different FILTER conditions; the expected plan joins one aggregate per
+ * FILTER condition. */
+ @Test void testMultipleDistinctWithSameArgsDifferentFilterUsingJoin() {
+ final String sql = "select deptno, count(distinct sal) FILTER (WHERE sal > 1000), "
+ + "count(distinct sal) FILTER (WHERE sal > 500) from sales.emp group by deptno";
+ sql(sql)
+ .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN,
+ CoreRules.AGGREGATE_PROJECT_MERGE)
+ .check();
+ }
+
@Test void testDistinctWithFilterWithoutGroupBy() {
final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE
sal > 1000)\n"
+ "FROM emp";
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 177abc997b..bb0c2f9ca8 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2873,6 +2873,32 @@ LogicalAggregate(group=[{}], EXPR$0=[MIN($1) FILTER $3], EXPR$1=[COUNT($0) FILTE
LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)], $g=[GROUPING($1, $2)])
LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testDistinctWithFilterWithoutGroupByUsingJoin">
+ <Resource name="sql">
+ <![CDATA[SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)
+FROM emp]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0)], EXPR$1=[COUNT(DISTINCT $1) FILTER $2])
+ LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalJoin(condition=[true], joinType=[inner])
+ LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
+ LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalAggregate(group=[{}], EXPR$1=[COUNT($0)])
+ LogicalAggregate(group=[{0}])
+ LogicalProject(i$SAL=[CASE($2, $1, null:INTEGER)])
+ LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
@@ -7279,6 +7305,34 @@ LogicalProject(SAL=[$0], EXPR$1=[$1], EXPR$2=[$3], EXPR$3=[$5])
LogicalProject(SAL=[$0])
LogicalProject(SAL=[$5], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultipleDistinctWithSameArgsDifferentFilterUsingJoin">
+ <Resource name="sql">
+ <![CDATA[select deptno, count(distinct sal) FILTER (WHERE sal > 1000), count(distinct sal) FILTER (WHERE sal > 500) from sales.emp group by deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[COUNT(DISTINCT $1) FILTER $2], EXPR$2=[COUNT(DISTINCT $1) FILTER $3])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[$1], EXPR$2=[$3])
+ LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $2)], joinType=[inner])
+ LogicalAggregate(group=[{0}], EXPR$1=[COUNT($1)])
+ LogicalAggregate(group=[{0, 1}])
+ LogicalProject(DEPTNO=[$0], i$SAL=[CASE($2, $1, null:INTEGER)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalAggregate(group=[{0}], EXPR$2=[COUNT($1)])
+ LogicalAggregate(group=[{0, 1}])
+ LogicalProject(DEPTNO=[$0], i$SAL=[CASE($3, $1, null:INTEGER)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>