This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 283c1c111d [CALCITE-6333] NullPointerException in AggregateExpandDistinctAggregatesRule.doRewrite when rewriting filtered distinct aggregation
283c1c111d is described below

commit 283c1c111d2d2682c728f681bebc5307d40d70a5
Author: abhishekagarwal87 <[email protected]>
AuthorDate: Fri Mar 15 15:49:34 2024 +0530

    [CALCITE-6333] NullPointerException in AggregateExpandDistinctAggregatesRule.doRewrite when rewriting filtered distinct aggregation
    
    Fix test order
    
    Fix one more
---
 .../AggregateExpandDistinctAggregatesRule.java     |  9 ++--
 .../org/apache/calcite/test/RelOptRulesTest.java   | 21 +++++++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 54 ++++++++++++++++++++++
 3 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
index ccf32284e7..56e224572c 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
@@ -741,6 +741,9 @@ public final class AggregateExpandDistinctAggregatesRule
       if (!aggCall.getArgList().equals(argList)) {
         continue;
       }
+      if (aggCall.filterArg != filterArg) {
+        continue;
+      }
 
       // Re-map arguments.
       final int argCount = aggCall.getArgList().size();
@@ -748,14 +751,10 @@ public final class AggregateExpandDistinctAggregatesRule
       for (Integer arg : aggCall.getArgList()) {
         newArgs.add(requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")"));
       }
-      final int newFilterArg =
-          aggCall.filterArg < 0 ? -1
-              : requireNonNull(sourceOf.get(aggCall.filterArg),
-                  () -> "sourceOf.get(" + aggCall.filterArg + ")");
       final AggregateCall newAggCall =
           AggregateCall.create(aggCall.getAggregation(), false,
               aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList,
-              newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation,
+              newArgs, -1, aggCall.distinctKeys, aggCall.collation,
               aggCall.getType(), aggCall.getName());
       assert refs.get(i) == null;
       if (leftFields == null) {
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 0681d4b7fc..d02cf05c1d 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -2175,6 +2175,27 @@ class RelOptRulesTest extends RelOptTestBase {
         .check();
   }
 
+  @Test void testDistinctWithFilterWithoutGroupByUsingJoin() {
+    final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n"
+                       + "FROM emp";
+    sql(sql)
+        .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN)
+        .check();
+  }
+
+  @Test void testMultipleDistinctWithSameArgsDifferentFilterUsingJoin() {
+    final String sql = "select deptno, "
+                       + "count(distinct sal) FILTER (WHERE sal > 1000), "
+                       + "count(distinct sal) FILTER (WHERE sal > 500) "
+                       + "from sales.emp group by deptno";
+    sql(sql)
+        .withRule(
+            CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN,
+            CoreRules.AGGREGATE_PROJECT_MERGE
+        )
+        .check();
+  }
+
   @Test void testDistinctWithFilterWithoutGroupBy() {
     final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)\n"
         + "FROM emp";
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 177abc997b..bb0c2f9ca8 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2873,6 +2873,32 @@ LogicalAggregate(group=[{}], EXPR$0=[MIN($1) FILTER $3], EXPR$1=[COUNT($0) FILTE
     LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)], $g=[GROUPING($1, $2)])
       LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testDistinctWithFilterWithoutGroupByUsingJoin">
+    <Resource name="sql">
+      <![CDATA[SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 1000)
+FROM emp]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0)], EXPR$1=[COUNT(DISTINCT $1) FILTER $2])
+  LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalJoin(condition=[true], joinType=[inner])
+  LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
+    LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+  LogicalAggregate(group=[{}], EXPR$1=[COUNT($0)])
+    LogicalAggregate(group=[{0}])
+      LogicalProject(i$SAL=[CASE($2, $1, null:INTEGER)])
+        LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
   </TestCase>
@@ -7279,6 +7305,34 @@ LogicalProject(SAL=[$0], EXPR$1=[$1], EXPR$2=[$3], EXPR$3=[$5])
         LogicalProject(SAL=[$0])
           LogicalProject(SAL=[$5], COMM=[$6])
             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testMultipleDistinctWithSameArgsDifferentFilterUsingJoin">
+    <Resource name="sql">
+      <![CDATA[select deptno, count(distinct sal) FILTER (WHERE sal > 1000), count(distinct sal) FILTER (WHERE sal > 500) from sales.emp group by deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[COUNT(DISTINCT $1) FILTER $2], EXPR$2=[COUNT(DISTINCT $1) FILTER $3])
+  LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[$1], EXPR$2=[$3])
+  LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $2)], joinType=[inner])
+    LogicalAggregate(group=[{0}], EXPR$1=[COUNT($1)])
+      LogicalAggregate(group=[{0, 1}])
+        LogicalProject(DEPTNO=[$0], i$SAL=[CASE($2, $1, null:INTEGER)])
+          LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)])
+            LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0}], EXPR$2=[COUNT($1)])
+      LogicalAggregate(group=[{0, 1}])
+        LogicalProject(DEPTNO=[$0], i$SAL=[CASE($3, $1, null:INTEGER)])
+          LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($5, 1000)], $f3=[>($5, 500)])
+            LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
   </TestCase>

Reply via email to