This is an automated email from the ASF dual-hosted git repository.

kgyrtkirk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 05787a6793 [CALCITE-5953] AggregateCaseToFilterRule may make inaccurate SUM transformations
05787a6793 is described below

commit 05787a6793d5393ee05c99182e05823d12edab12
Author: Zoltan Haindrich <[email protected]>
AuthorDate: Mon Aug 28 16:21:09 2023 +0000

    [CALCITE-5953] AggregateCaseToFilterRule may make inaccurate SUM transformations
    
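    The aggregate SUM(case when x = 1 then 1 else 0 end):
    * must be NULL only when there are no input rows
    * may be 0 only if there is at least one input row that doesn't match the filter

    Rewriting it to COUNT() FILTER (x = 1) would return 0 instead of NULL when
    there are no input rows, so these rewrites now apply only to SUM0, not SUM.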
---
 .../rel/rules/AggregateCaseToFilterRule.java       | 18 +++----
 .../org/apache/calcite/test/RelOptRulesTest.java   |  9 ++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 23 +++++++--
 core/src/test/resources/sql/agg.iq                 | 55 ++++++++++++++++++++++
 4 files changed, 92 insertions(+), 13 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateCaseToFilterRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateCaseToFilterRule.java
index f4642855ac..303b2d2c78 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateCaseToFilterRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateCaseToFilterRule.java
@@ -196,14 +196,14 @@ public class AggregateCaseToFilterRule
 
     // Four styles supported:
     //
-    // A1: AGG(CASE WHEN x = 'foo' THEN cnt END)
-    //   => operands (x = 'foo', cnt, null)
-    // A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
-    //   => operands (x = 'foo', cnt, 0); must be SUM
-    // B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
-    //   => operands (x = 'foo', 1, 0); must be SUM
+    // A1: AGG(CASE WHEN x = 'foo' THEN expr END)
+    //   => AGG(expr) FILTER (x = 'foo')
+    // A2: SUM0(CASE WHEN x = 'foo' THEN cnt ELSE 0 END)
+    //   => SUM0(cnt) FILTER (x = 'foo')
+    // B: SUM0(CASE WHEN x = 'foo' THEN 1 ELSE 0 END)
+    //   => COUNT() FILTER (x = 'foo')
     // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END)
-    //   => operands (x = 'foo', 'dummy', null)
+    //   => COUNT() FILTER (x = 'foo')
 
     if (kind == SqlKind.COUNT // Case C
         && arg1.isA(SqlKind.LITERAL)
@@ -214,7 +214,7 @@ public class AggregateCaseToFilterRule
           false, call.rexList, ImmutableList.of(), newProjects.size() - 1, 
null,
           RelCollations.EMPTY, call.getType(),
           call.getName());
-    } else if (kind == SqlKind.SUM // Case B
+    } else if (kind == SqlKind.SUM0 // Case B
         && isIntLiteral(arg1, BigDecimal.ONE)
         && isIntLiteral(arg2, BigDecimal.ZERO)) {
 
@@ -228,7 +228,7 @@ public class AggregateCaseToFilterRule
           RelCollations.EMPTY, dataType, call.getName());
     } else if ((RexLiteral.isNullLiteral(arg2) // Case A1
             && call.getAggregation().allowsFilter())
-        || (kind == SqlKind.SUM // Case A2
+        || (kind == SqlKind.SUM0 // Case A2
             && isIntLiteral(arg2, BigDecimal.ZERO))) {
       newProjects.add(arg1);
       newProjects.add(filter);
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 0bb72d07f4..a7b4088219 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -4411,6 +4411,15 @@ class RelOptRulesTest extends RelOptTestBase {
     sql(sql).withRule(CoreRules.AGGREGATE_CASE_TO_FILTER).check();
   }
 
+  @Test void testAggregateCaseToFilterNoMatch() {
+    final String sql = "select\n"
+        + " sum(case when deptno = -1 then 1 else 0 end) as sum_no_match,\n"
+        + " sum(case when deptno = -1 then 2 else 0 end) as sum_no_match2,\n"
+        + " sum(case when deptno = -1 then 3 else -1 end) as sum_no_match3\n"
+        + "from emp";
+    sql(sql).withRule(CoreRules.AGGREGATE_CASE_TO_FILTER).checkUnchanged();
+  }
+
   @Test void testPullAggregateThroughUnion() {
     final String sql = "select deptno, job from"
         + " (select deptno, job from emp as e1"
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index f2ae0acdbe..8b965e3c6a 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -66,10 +66,25 @@ LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], 
COUNT_DISTINCT_CLERK=[COUNT(DIST
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
-LogicalProject(SUM_SAL=[$0], COUNT_DISTINCT_CLERK=[$1], SUM_SAL_D10=[$2], 
SUM_SAL_D20=[$3], COUNT_D30=[CAST($4):INTEGER], COUNT_D40=[$5], COUNT_D45=[$6], 
COUNT_D50=[$7], SUM_NULL_D60=[$8], SUM_NULL_D70=[$9], COUNT_D20=[$10])
-  LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], 
COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $2) FILTER $3], SUM_SAL_D10=[SUM($4) 
FILTER $5], SUM_SAL_D20=[SUM($6) FILTER $7], COUNT_D30=[COUNT() FILTER $8], 
COUNT_D40=[COUNT() FILTER $9], COUNT_D45=[SUM($10) FILTER $11], 
COUNT_D50=[SUM($12) FILTER $13], SUM_NULL_D60=[SUM($1)], SUM_NULL_D70=[SUM($14) 
FILTER $15], COUNT_D20=[COUNT() FILTER $16])
-    LogicalProject(SAL=[$5], $f8=[null:DECIMAL(19, 9)], DEPTNO=[$7], 
$f12=[=($2, 'CLERK')], SAL0=[$5], $f14=[=($7, 10)], SAL1=[$5], $f16=[=($7, 
20)], $f17=[=($7, 30)], $f18=[=($7, 40)], $f19=[1], $f20=[=($7, 45)], $f21=[1], 
$f22=[=($7, 50)], $f23=[1], $f24=[<>($7, 70)], $f25=[=($7, 20)])
-      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], 
COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $4) FILTER $5], SUM_SAL_D10=[SUM($6) 
FILTER $7], SUM_SAL_D20=[SUM($1)], COUNT_D30=[SUM($2)], COUNT_D40=[COUNT() 
FILTER $8], COUNT_D45=[SUM($9) FILTER $10], COUNT_D50=[SUM($11) FILTER $12], 
SUM_NULL_D60=[SUM($3)], SUM_NULL_D70=[SUM($13) FILTER $14], COUNT_D20=[COUNT() 
FILTER $15])
+  LogicalProject(SAL=[$5], $f3=[CASE(=($7, 20), $5, 0)], $f4=[CASE(=($7, 30), 
1, 0)], $f8=[null:DECIMAL(19, 9)], DEPTNO=[$7], $f12=[=($2, 'CLERK')], 
SAL0=[$5], $f14=[=($7, 10)], $f15=[=($7, 40)], $f16=[1], $f17=[=($7, 45)], 
$f18=[1], $f19=[=($7, 50)], $f20=[1], $f21=[<>($7, 70)], $f22=[=($7, 20)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testAggregateCaseToFilterNoMatch">
+    <Resource name="sql">
+      <![CDATA[select
+ sum(case when deptno = -1 then 1 else 0 end) as sum_no_match,
+ sum(case when deptno = -1 then 2 else 0 end) as sum_no_match2,
+ sum(case when deptno = -1 then 3 else -1 end) as sum_no_match3
+from emp]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], SUM_NO_MATCH=[SUM($0)], SUM_NO_MATCH2=[SUM($1)], 
SUM_NO_MATCH3=[SUM($2)])
+  LogicalProject($f0=[CASE(=($7, -1), 1, 0)], $f1=[CASE(=($7, -1), 2, 0)], 
$f2=[CASE(=($7, -1), 3, -1)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
   </TestCase>
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index 89b7dc15e8..8541f46ba5 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -1679,6 +1679,61 @@ from (values 0, null, 0, 1) as t(x);
 
 !ok
 
+# Convert CASE to FILTER with no input rows
+select  COALESCE(sum(case when x = 1 then 1 else 0 end),0) as a,
+        COALESCE(sum(case when x = 1 then 2 else 0 end),0) as b,
+        COALESCE(sum(case when x = 1 then 3 else -1 end),0) as c
+from (values 0, null, 0, 2) as t(x) where x*x=1;
++---+---+---+
+| A | B | C |
++---+---+---+
+| 0 | 0 | 0 |
++---+---+---+
+(1 row)
+
+!ok
+EnumerableCalc(expr#0..2=[{inputs}], expr#3=[CAST($t0):INTEGER NOT NULL], 
$f0=[$t3], $f1=[$t1], $f2=[$t2])
+  EnumerableAggregate(group=[{}], agg#0=[COUNT() FILTER $1], agg#1=[$SUM0($2) 
FILTER $3], agg#2=[$SUM0($0)])
+    EnumerableCalc(expr#0=[{inputs}], expr#1=[1], expr#2=[=($t0, $t1)], 
expr#3=[3], expr#4=[-1], expr#5=[CASE($t2, $t3, $t4)], expr#6=[IS TRUE($t2)], 
expr#7=[2], expr#8=[*($t0, $t0)], expr#9=[=($t8, $t1)], $f2=[$t5], $f3=[$t6], 
$f4=[$t7], $f5=[$t6], $condition=[$t9])
+      EnumerableValues(tuples=[[{ 0 }, { null }, { 0 }, { 2 }]])
+!plan
+
+# Convert CASE to FILTER with no input rows
+select  sum(case when x = 1 then 1 else 0 end) as a,
+        sum(case when x = 1 then 2 else 0 end) as b,
+        sum(case when x = 1 then 3 else -1 end) as c
+from (values 0, null, 0, 2) as t(x) where x*x=1;
++---+---+---+
+| A | B | C |
++---+---+---+
+|   |   |   |
++---+---+---+
+(1 row)
+
+!ok
+EnumerableAggregate(group=[{}], A=[SUM($0)], B=[SUM($1)], C=[SUM($2)])
+  EnumerableCalc(expr#0=[{inputs}], expr#1=[1], expr#2=[=($t0, $t1)], 
expr#3=[0], expr#4=[CASE($t2, $t1, $t3)], expr#5=[2], expr#6=[CASE($t2, $t5, 
$t3)], expr#7=[3], expr#8=[-1], expr#9=[CASE($t2, $t7, $t8)], expr#10=[*($t0, 
$t0)], expr#11=[=($t10, $t1)], $f0=[$t4], $f1=[$t6], $f2=[$t9], 
$condition=[$t11])
+    EnumerableValues(tuples=[[{ 0 }, { null }, { 0 }, { 2 }]])
+!plan
+
+# Convert CASE to FILTER without matches
+select  sum(case when x = 1 then 1 else 0 end) as a,
+        sum(case when x = 1 then 2 else 0 end) as b,
+        sum(case when x = 1 then 3 else -1 end) as c
+from (values 0, null, 0, 2) as t(x);
++---+---+----+
+| A | B | C  |
++---+---+----+
+| 0 | 0 | -4 |
++---+---+----+
+(1 row)
+
+!ok
+EnumerableAggregate(group=[{}], A=[SUM($0)], B=[SUM($1)], C=[SUM($2)])
+  EnumerableCalc(expr#0=[{inputs}], expr#1=[1], expr#2=[=($t0, $t1)], 
expr#3=[0], expr#4=[CASE($t2, $t1, $t3)], expr#5=[2], expr#6=[CASE($t2, $t5, 
$t3)], expr#7=[3], expr#8=[-1], expr#9=[CASE($t2, $t7, $t8)], $f0=[$t4], 
$f1=[$t6], $f2=[$t9])
+    EnumerableValues(tuples=[[{ 0 }, { null }, { 0 }, { 2 }]])
+!plan
+
 # Same, expressed as FILTER
 select count(*) filter (where (x = 0) is not true) as c
 from (values 0, null, 0, 1) as t(x);

Reply via email to