This is an automated email from the ASF dual-hosted git repository.

zhenchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new d63996d487 [CALCITE-7086] Implement a rule that performs the inverse operation of AggregateCaseToFilterRule
d63996d487 is described below

commit d63996d487d4395cefc65157ef4a5596741bfbcd
Author: Silun Dong <[email protected]>
AuthorDate: Fri Jul 18 13:58:22 2025 +0800

    [CALCITE-7086] Implement a rule that performs the inverse operation of AggregateCaseToFilterRule
---
 .../rel/rules/AggregateFilterToCaseRule.java       | 156 +++++++++++++++++++++
 .../org/apache/calcite/rel/rules/CoreRules.java    |   4 +
 .../calcite/rel/rel2sql/RelToSqlConverterTest.java |  21 +++
 .../org/apache/calcite/test/RelOptRulesTest.java   |  23 +++
 .../org/apache/calcite/test/RelOptRulesTest.xml    |  29 ++++
 core/src/test/resources/sql/planner.iq             |  32 +++++
 6 files changed, 265 insertions(+)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToCaseRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToCaseRule.java
new file mode 100644
index 0000000000..d76a3be0d8
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToCaseRule.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.rules;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.tools.RelBuilder;
+
+import com.google.common.collect.ImmutableList;
+
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Rule that converts true filtered aggregates into CASE-style filtered aggregates.
+ *
+ * <p>For example,
+ *
+ * <blockquote>
+ *   <code>SELECT SUM(salary) FILTER (WHERE gender = 'F')<br>
+ *   FROM Emp</code>
+ * </blockquote>
+ *
+ * <p>becomes
+ *
+ * <blockquote>
+ *   <code>SELECT SUM(CASE WHEN gender = 'F' THEN salary END)<br>
+ *   FROM Emp</code>
+ * </blockquote>
+ *
+ * @see CoreRules#AGGREGATE_FILTER_TO_CASE
+ */
+@Value.Enclosing
+public class AggregateFilterToCaseRule
+    extends RelRule<AggregateFilterToCaseRule.Config>
+    implements TransformationRule {
+
+  /** Creates an AggregateFilterToCaseRule. */
+  protected AggregateFilterToCaseRule(Config config) {
+    super(config);
+  }
+
+  @Override public boolean matches(RelOptRuleCall call) {
+    final Aggregate aggregate = call.rel(0);
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      if (aggregateCall.hasFilter()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  @Override public void onMatch(RelOptRuleCall call) {
+    final RelBuilder relBuilder = call.builder();
+    final Aggregate aggregate = call.rel(0);
+    final Project project = call.rel(1);
+    final List<AggregateCall> newCalls =
+        new ArrayList<>(aggregate.getAggCallList().size());
+    final List<RexNode> newProjects = new ArrayList<>(project.getProjects());
+
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      AggregateCall newCall =
+          transform(
+              aggregateCall,
+              relBuilder.getRexBuilder(),
+              relBuilder.getTypeFactory(),
+              newProjects);
+      newCalls.add(newCall);
+    }
+
+    if (newCalls.equals(aggregate.getAggCallList())) {
+      return;
+    }
+
+    relBuilder
+        .push(project.getInput())
+        .project(newProjects);
+    final RelBuilder.GroupKey groupKey =
+        relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets());
+    relBuilder.aggregate(groupKey, newCalls);
+    call.transformTo(relBuilder.build());
+  }
+
+  private static AggregateCall transform(AggregateCall call,
+      RexBuilder rexBuilder, RelDataTypeFactory typeFactory, List<RexNode> newProjects) {
+    if (!call.hasFilter()) {
+      return call;
+    }
+    final SqlKind kind = call.getAggregation().getKind();
+    final RexNode condition = newProjects.get(call.filterArg);
+    final RexNode arg1;
+    final RexNode arg2;
+
+    if (kind == SqlKind.COUNT) {
+      // COUNT function may have no argument. When building the CASE expression,
+      // fill arg1 with "dummy":
+      // COUNT() FILTER (x = 'foo') => COUNT(CASE WHEN x = 'foo' THEN 0 END)
+      arg1 = call.getArgList().size() == 0
+          ? rexBuilder.makeZeroLiteral(typeFactory.createSqlType(SqlTypeName.INTEGER))
+          : newProjects.get(call.getArgList().get(0));
+    } else if (call.isDistinct() || call.getArgList().size() != 1) {
+      // ensure the reversibility of transformation, refer to AggregateCaseToFilterRule,
+      // when the aggregate function is with distinct, only the COUNT can be converted.
+      return call;
+    } else {
+      arg1 = newProjects.get(call.getArgList().get(0));
+    }
+
+    arg2 = rexBuilder.makeNullLiteral(arg1.getType());
+    final RexNode caseWhen = rexBuilder.makeCall(SqlStdOperatorTable.CASE, condition, arg1, arg2);
+    newProjects.add(caseWhen);
+    return AggregateCall.create(call.getParserPosition(), call.getAggregation(), call.isDistinct(),
+        call.isApproximate(), call.ignoreNulls(), call.rexList,
+        ImmutableList.of(newProjects.size() - 1), -1, call.distinctKeys, RelCollations.EMPTY,
+        call.getType(), call.getName());
+  }
+
+  /** Rule configuration. */
+  @Value.Immutable
+  public interface Config extends RelRule.Config {
+    Config DEFAULT = ImmutableAggregateFilterToCaseRule.Config.of()
+        .withOperandSupplier(b0 ->
+            b0.operand(Aggregate.class).oneInput(b1 ->
+                b1.operand(Project.class).anyInputs()));
+
+    @Override default AggregateFilterToCaseRule toRule() {
+      return new AggregateFilterToCaseRule(this);
+    }
+  }
+}
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
index 4b0d5127e9..ac8c8729b4 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
@@ -957,4 +957,8 @@ private CoreRules() {}
   /** Rule that convert FULL JOIN to LEFT JOIN and RIGHT JOIN. */
   public static final FullToLeftAndRightJoinRule FULL_TO_LEFT_AND_RIGHT_JOIN =
       FullToLeftAndRightJoinRule.Config.DEFAULT.toRule();
+
+  /** Rule that converts true filtered aggregates into CASE-style filtered aggregates. */
+  public static final AggregateFilterToCaseRule AGGREGATE_FILTER_TO_CASE =
+      AggregateFilterToCaseRule.Config.DEFAULT.toRule();
 }
diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
index 9476302cd8..8be953d2ba 100644
--- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
@@ -10432,6 +10432,27 @@ private void checkLiteral2(String expression, String expected) {
         .ok(expected);
   }
 
+  @Test void testAggregateFilterToCase() {
+    final String query = "select\n"
+        + " sum(sal) filter(where deptno = 10) as sum_match,\n"
+        + " count(distinct deptno) filter(where job = 'CLERK') as count_distinct_match,\n"
+        + " count(*) filter(where deptno = 40) as count_star_match\n"
+        + " from emp";
+    final String expected = "SELECT"
+        + " SUM(CASE WHEN CAST(\"DEPTNO\" AS INTEGER) = 10 THEN \"SAL\" ELSE NULL END)"
+        + " AS \"SUM_MATCH\","
+        + " COUNT(DISTINCT CASE WHEN \"JOB\" = 'CLERK' THEN \"DEPTNO\" ELSE NULL END)"
+        + " AS \"COUNT_DISTINCT_MATCH\","
+        + " COUNT(CASE WHEN CAST(\"DEPTNO\" AS INTEGER) = 40 THEN 0 ELSE NULL END)"
+        + " AS \"COUNT_STAR_MATCH\"\nFROM \"SCOTT\".\"EMP\"";
+
+    sql(query)
+        .schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
+        .withCalcite()
+        .optimize(RuleSets.ofList(CoreRules.AGGREGATE_FILTER_TO_CASE), null)
+        .ok(expected);
+  }
+
   /** Fluid interface to run tests. */
   static class Sql {
     private final CalciteAssert.SchemaSpec schemaSpec;
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 5e54845c81..a48bf6907a 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -11087,4 +11087,27 @@ private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
         .withRule(CoreRules.FULL_TO_LEFT_AND_RIGHT_JOIN)
         .checkUnchanged();
   }
+
+  /** Test case of
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-7086">[CALCITE-7086]
+   * Implement a rule that performs the inverse operation of AggregateCaseToFilterRule</a>. */
+  @Test void testAggregateFilterToCase() {
+    final String sql = "select coalesce(sum_match, 0) as sum0_match, sum_distinct_not_match,\n"
+        + " count_distinct_match, count_star_match from (\n"
+        + " select\n"
+        + " sum(sal) filter(where deptno = 10) as sum_match,\n"
+        + " sum(distinct empno) filter(where deptno = 20) as sum_distinct_not_match,\n"
+        + " count(distinct deptno) filter(where job = 'CLERK') as count_distinct_match,\n"
+        + " count(*) filter(where deptno = 40) as count_star_match\n"
+        + " from emp\n"
+        + " )";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(CoreRules.PROJECT_AGGREGATE_MERGE)
+        .build();
+
+    sql(sql)
+        .withPre(program)
+        .withRule(CoreRules.AGGREGATE_FILTER_TO_CASE)
+        .check();
+  }
 }
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index c61d39b755..03278394ab 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -288,6 +288,35 @@ LogicalAggregate(group=[{0, 7}], groups=[[{0, 7}, {0}, {7}]], EXPR$2=[SUM($0)])
 LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}, {1}]], EXPR$2=[SUM($0)])
   LogicalProject(EMPNO=[$0], DEPTNO=[$7])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testAggregateFilterToCase">
+    <Resource name="sql">
+      <![CDATA[select coalesce(sum_match, 0) as sum0_match, sum_distinct_not_match,
+ count_distinct_match, count_star_match from (
+ select
+ sum(sal) filter(where deptno = 10) as sum_match,
+ sum(distinct empno) filter(where deptno = 20) as sum_distinct_not_match,
+ count(distinct deptno) filter(where job = 'CLERK') as count_distinct_match,
+ count(*) filter(where deptno = 40) as count_star_match
+ from emp
+ )]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(SUM0_MATCH=[$3], SUM_DISTINCT_NOT_MATCH=[$0], COUNT_DISTINCT_MATCH=[$1], COUNT_STAR_MATCH=[$2])
+  LogicalAggregate(group=[{}], SUM_DISTINCT_NOT_MATCH=[SUM(DISTINCT $2) FILTER $3], COUNT_DISTINCT_MATCH=[COUNT(DISTINCT $4) FILTER $5], COUNT_STAR_MATCH=[COUNT() FILTER $6], agg#3=[$SUM0($0) FILTER $1])
+    LogicalProject(SAL=[$5], $f1=[=($7, 10)], EMPNO=[$0], $f3=[=($7, 20)], DEPTNO=[$7], $f5=[=($2, 'CLERK')], $f6=[=($7, 40)])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(SUM0_MATCH=[$3], SUM_DISTINCT_NOT_MATCH=[$0], COUNT_DISTINCT_MATCH=[$1], COUNT_STAR_MATCH=[$2])
+  LogicalAggregate(group=[{}], SUM_DISTINCT_NOT_MATCH=[SUM(DISTINCT $0) FILTER $1], COUNT_DISTINCT_MATCH=[COUNT(DISTINCT $2)], COUNT_STAR_MATCH=[COUNT($3)], agg#3=[$SUM0($4)])
+    LogicalProject(EMPNO=[$0], $f3=[=($7, 20)], $f7=[CASE(=($2, 'CLERK'), $7, null:INTEGER)], $f8=[CASE(=($7, 40), 0, null:INTEGER)], $f9=[CASE(=($7, 10), $5, null:INTEGER)])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
   </TestCase>
diff --git a/core/src/test/resources/sql/planner.iq b/core/src/test/resources/sql/planner.iq
index 47231cae12..844962d0ad 100644
--- a/core/src/test/resources/sql/planner.iq
+++ b/core/src/test/resources/sql/planner.iq
@@ -319,4 +319,36 @@ EnumerableHashJoin(condition=[AND(=($0, $6), OR(AND(>($1, 11), <=($7, 32)), AND(
 !plan
 !set planner-rules original
 
+# [CALCITE-7086] Implement a rule that performs the inverse operation of AggregateCaseToFilterRule
+# Refer to RelToSqlConverterTest.testAggregateFilterToCase(). The following two SQL
+# represent the true filtered Aggregate and the case-style Aggregate converted by AggregateFilterToCaseRule.
+!use scott
+select
+ sum(sal) filter(where deptno = 10) as sum_match,
+ count(distinct deptno) filter(where job = 'CLERK') as count_distinct_match,
+ count(*) filter(where deptno = 40) as count_star_match
+ from emp;
++-----------+----------------------+------------------+
+| SUM_MATCH | COUNT_DISTINCT_MATCH | COUNT_STAR_MATCH |
++-----------+----------------------+------------------+
+|   8750.00 |                    3 |                0 |
++-----------+----------------------+------------------+
+(1 row)
+
+!ok
+
+SELECT
+SUM(CASE WHEN CAST("DEPTNO" AS INTEGER) = 10 THEN "SAL" ELSE NULL END) AS "SUM_MATCH",
+COUNT(DISTINCT CASE WHEN "JOB" = 'CLERK' THEN "DEPTNO" ELSE NULL END) AS "COUNT_DISTINCT_MATCH",
+COUNT(CASE WHEN CAST("DEPTNO" AS INTEGER) = 40 THEN 0 ELSE NULL END) AS "COUNT_STAR_MATCH"
+FROM "scott"."EMP";
++-----------+----------------------+------------------+
+| SUM_MATCH | COUNT_DISTINCT_MATCH | COUNT_STAR_MATCH |
++-----------+----------------------+------------------+
+|   8750.00 |                    3 |                0 |
++-----------+----------------------+------------------+
+(1 row)
+
+!ok
+
 # End planner.iq

Reply via email to