This is an automated email from the ASF dual-hosted git repository.
zhenchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new d63996d487 [CALCITE-7086] Implement a rule that performs the inverse
operation of AggregateCaseToFilterRule
d63996d487 is described below
commit d63996d487d4395cefc65157ef4a5596741bfbcd
Author: Silun Dong <[email protected]>
AuthorDate: Fri Jul 18 13:58:22 2025 +0800
[CALCITE-7086] Implement a rule that performs the inverse operation of
AggregateCaseToFilterRule
---
.../rel/rules/AggregateFilterToCaseRule.java | 156 +++++++++++++++++++++
.../org/apache/calcite/rel/rules/CoreRules.java | 4 +
.../calcite/rel/rel2sql/RelToSqlConverterTest.java | 21 +++
.../org/apache/calcite/test/RelOptRulesTest.java | 23 +++
.../org/apache/calcite/test/RelOptRulesTest.xml | 29 ++++
core/src/test/resources/sql/planner.iq | 32 +++++
6 files changed, 265 insertions(+)
diff --git
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToCaseRule.java
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToCaseRule.java
new file mode 100644
index 0000000000..d76a3be0d8
--- /dev/null
+++
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToCaseRule.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.rules;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.tools.RelBuilder;
+
+import com.google.common.collect.ImmutableList;
+
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Rule that converts true filtered aggregates into CASE-style filtered aggregates.
+ *
+ * <p>For example,
+ *
+ * <blockquote>
+ * <code>SELECT SUM(salary) FILTER (WHERE gender = 'F')<br>
+ * FROM Emp</code>
+ * </blockquote>
+ *
+ * <p>becomes
+ *
+ * <blockquote>
+ * <code>SELECT SUM(CASE WHEN gender = 'F' THEN salary END)<br>
+ * FROM Emp</code>
+ * </blockquote>
+ *
+ * <p>This is the inverse operation of {@code AggregateCaseToFilterRule}.
+ * Aggregate calls that cannot be rewritten are kept as they are: calls without
+ * a filter, calls with a collation, COUNT calls with more than one argument,
+ * and non-COUNT calls that are distinct or do not have exactly one argument.
+ *
+ * @see CoreRules#AGGREGATE_FILTER_TO_CASE
+ */
+@Value.Enclosing
+public class AggregateFilterToCaseRule
+    extends RelRule<AggregateFilterToCaseRule.Config>
+    implements TransformationRule {
+
+  /** Creates an AggregateFilterToCaseRule. */
+  protected AggregateFilterToCaseRule(Config config) {
+    super(config);
+  }
+
+  /** Matches only if at least one aggregate call of the Aggregate has a filter. */
+  @Override public boolean matches(RelOptRuleCall call) {
+    final Aggregate aggregate = call.rel(0);
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      if (aggregateCall.hasFilter()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /**
+   * Rebuilds the Project and the Aggregate so that each rewritable filtered
+   * call reads a new CASE expression instead of using a filter.
+   */
+  @Override public void onMatch(RelOptRuleCall call) {
+    final RelBuilder relBuilder = call.builder();
+    final Aggregate aggregate = call.rel(0);
+    final Project project = call.rel(1);
+    final List<AggregateCall> newCalls =
+        new ArrayList<>(aggregate.getAggCallList().size());
+    // Start from a copy of the existing projects. transform() only appends to
+    // this list, so the ordinals used by the group set and by calls that are
+    // not rewritten stay valid.
+    final List<RexNode> newProjects = new ArrayList<>(project.getProjects());
+
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      AggregateCall newCall =
+          transform(
+              aggregateCall,
+              relBuilder.getRexBuilder(),
+              relBuilder.getTypeFactory(),
+              newProjects);
+      newCalls.add(newCall);
+    }
+
+    // No call was rewritten; do not register an equivalent plan.
+    if (newCalls.equals(aggregate.getAggCallList())) {
+      return;
+    }
+
+    relBuilder
+        .push(project.getInput())
+        .project(newProjects);
+    final RelBuilder.GroupKey groupKey =
+        relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets());
+    relBuilder.aggregate(groupKey, newCalls);
+    call.transformTo(relBuilder.build());
+  }
+
+  /**
+   * Rewrites one filtered aggregate call into a call without a filter whose
+   * single argument is a CASE expression appended to {@code newProjects}.
+   *
+   * @param call Aggregate call to rewrite
+   * @param rexBuilder Builder used to create the CASE call and its literals
+   * @param typeFactory Type factory used to type the dummy COUNT argument
+   * @param newProjects Project expressions; a CASE expression is appended
+   *                    when the call is rewritten, existing entries are kept
+   * @return the rewritten call, or {@code call} itself if it is not rewritten
+   */
+  private static AggregateCall transform(AggregateCall call,
+      RexBuilder rexBuilder, RelDataTypeFactory typeFactory,
+      List<RexNode> newProjects) {
+    if (!call.hasFilter()) {
+      return call;
+    }
+    if (!call.getCollation().getFieldCollations().isEmpty()) {
+      // The rewritten call is created with an empty collation; rewriting an
+      // ordered aggregate would silently drop its ordering.
+      return call;
+    }
+    final SqlKind kind = call.getAggregation().getKind();
+    final List<Integer> argList = call.getArgList();
+    final RexNode arg1;
+
+    if (kind == SqlKind.COUNT) {
+      if (argList.size() > 1) {
+        // The CASE expression carries a single value; rewriting would drop
+        // every argument after the first.
+        return call;
+      }
+      // COUNT function may have no argument. When building the CASE expression,
+      // fill arg1 with "dummy":
+      //   COUNT() FILTER (x = 'foo') => COUNT(CASE WHEN x = 'foo' THEN 0 END)
+      arg1 = argList.isEmpty()
+          ? rexBuilder.makeZeroLiteral(typeFactory.createSqlType(SqlTypeName.INTEGER))
+          : newProjects.get(argList.get(0));
+    } else if (call.isDistinct() || argList.size() != 1) {
+      // ensure the reversibility of transformation, refer to AggregateCaseToFilterRule,
+      // when the aggregate function is with distinct, only the COUNT can be converted.
+      return call;
+    } else {
+      arg1 = newProjects.get(argList.get(0));
+    }
+
+    final RexNode condition = newProjects.get(call.filterArg);
+    // ELSE branch: a NULL literal of the same type as the THEN branch.
+    final RexNode arg2 = rexBuilder.makeNullLiteral(arg1.getType());
+    final RexNode caseWhen =
+        rexBuilder.makeCall(SqlStdOperatorTable.CASE, condition, arg1, arg2);
+    newProjects.add(caseWhen);
+    // The new call reads the appended CASE expression and has no filter (-1).
+    return AggregateCall.create(call.getParserPosition(),
+        call.getAggregation(), call.isDistinct(),
+        call.isApproximate(), call.ignoreNulls(), call.rexList,
+        ImmutableList.of(newProjects.size() - 1), -1, call.distinctKeys,
+        RelCollations.EMPTY,
+        call.getType(), call.getName());
+  }
+
+  /** Rule configuration. */
+  @Value.Immutable
+  public interface Config extends RelRule.Config {
+    /** Default configuration: an Aggregate whose single input is a Project. */
+    Config DEFAULT = ImmutableAggregateFilterToCaseRule.Config.of()
+        .withOperandSupplier(b0 ->
+            b0.operand(Aggregate.class).oneInput(b1 ->
+                b1.operand(Project.class).anyInputs()));
+
+    @Override default AggregateFilterToCaseRule toRule() {
+      return new AggregateFilterToCaseRule(this);
+    }
+  }
+}
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
index 4b0d5127e9..ac8c8729b4 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
@@ -957,4 +957,8 @@ private CoreRules() {}
/** Rule that convert FULL JOIN to LEFT JOIN and RIGHT JOIN. */
public static final FullToLeftAndRightJoinRule FULL_TO_LEFT_AND_RIGHT_JOIN =
FullToLeftAndRightJoinRule.Config.DEFAULT.toRule();
+
+ /** Rule that converts true filtered aggregates into CASE-style filtered aggregates;
+ * it performs the inverse operation of {@code AggregateCaseToFilterRule}. */
+ public static final AggregateFilterToCaseRule AGGREGATE_FILTER_TO_CASE =
+ AggregateFilterToCaseRule.Config.DEFAULT.toRule();
}
diff --git
a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
index 9476302cd8..8be953d2ba 100644
---
a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
+++
b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
@@ -10432,6 +10432,27 @@ private void checkLiteral2(String expression, String
expected) {
.ok(expected);
}
+ @Test void testAggregateFilterToCase() { // [CALCITE-7086] FILTER is unparsed as a CASE argument
+ final String query = "select\n"
+ + " sum(sal) filter(where deptno = 10) as sum_match,\n"
+ + " count(distinct deptno) filter(where job = 'CLERK') as
count_distinct_match,\n"
+ + " count(*) filter(where deptno = 40) as count_star_match\n"
+ + " from emp";
+ final String expected = "SELECT"
+ + " SUM(CASE WHEN CAST(\"DEPTNO\" AS INTEGER) = 10 THEN \"SAL\" ELSE
NULL END)"
+ + " AS \"SUM_MATCH\","
+ + " COUNT(DISTINCT CASE WHEN \"JOB\" = 'CLERK' THEN \"DEPTNO\" ELSE
NULL END)"
+ + " AS \"COUNT_DISTINCT_MATCH\","
+ + " COUNT(CASE WHEN CAST(\"DEPTNO\" AS INTEGER) = 40 THEN 0 ELSE NULL
END)"
+ + " AS \"COUNT_STAR_MATCH\"\nFROM \"SCOTT\".\"EMP\"";
+
+ sql(query)
+ .schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
+ .withCalcite()
+ .optimize(RuleSets.ofList(CoreRules.AGGREGATE_FILTER_TO_CASE), null) // rule set holds only this rule
+ .ok(expected);
+ }
+
/** Fluid interface to run tests. */
static class Sql {
private final CalciteAssert.SchemaSpec schemaSpec;
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 5e54845c81..a48bf6907a 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -11087,4 +11087,27 @@ private void
checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
.withRule(CoreRules.FULL_TO_LEFT_AND_RIGHT_JOIN)
.checkUnchanged();
}
+
+ /** Test case of
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-7086">[CALCITE-7086]
+ * Implement a rule that performs the inverse operation of AggregateCaseToFilterRule</a>.
+ *
+ * <p>The query mixes calls that the rule rewrites (SUM, COUNT(DISTINCT) and
+ * COUNT(*)) with a SUM(DISTINCT) call, which keeps its filter in the expected
+ * plan. PROJECT_AGGREGATE_MERGE runs as the pre-program, before the rule under
+ * test is applied. */
+ @Test void testAggregateFilterToCase() {
+ final String sql = "select coalesce(sum_match, 0) as sum0_match,
sum_distinct_not_match,\n"
+ + " count_distinct_match, count_star_match from (\n"
+ + " select\n"
+ + " sum(sal) filter(where deptno = 10) as sum_match,\n"
+ + " sum(distinct empno) filter(where deptno = 20) as
sum_distinct_not_match,\n"
+ + " count(distinct deptno) filter(where job = 'CLERK') as
count_distinct_match,\n"
+ + " count(*) filter(where deptno = 40) as count_star_match\n"
+ + " from emp\n"
+ + " )";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(CoreRules.PROJECT_AGGREGATE_MERGE)
+ .build();
+
+ sql(sql)
+ .withPre(program)
+ .withRule(CoreRules.AGGREGATE_FILTER_TO_CASE)
+ .check();
+ }
}
diff --git
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index c61d39b755..03278394ab 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -288,6 +288,35 @@ LogicalAggregate(group=[{0, 7}], groups=[[{0, 7}, {0},
{7}]], EXPR$2=[SUM($0)])
LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}, {1}]], EXPR$2=[SUM($0)])
LogicalProject(EMPNO=[$0], DEPTNO=[$7])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testAggregateFilterToCase">
+ <Resource name="sql">
+ <![CDATA[select coalesce(sum_match, 0) as sum0_match,
sum_distinct_not_match,
+ count_distinct_match, count_star_match from (
+ select
+ sum(sal) filter(where deptno = 10) as sum_match,
+ sum(distinct empno) filter(where deptno = 20) as sum_distinct_not_match,
+ count(distinct deptno) filter(where job = 'CLERK') as count_distinct_match,
+ count(*) filter(where deptno = 40) as count_star_match
+ from emp
+ )]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(SUM0_MATCH=[$3], SUM_DISTINCT_NOT_MATCH=[$0],
COUNT_DISTINCT_MATCH=[$1], COUNT_STAR_MATCH=[$2])
+ LogicalAggregate(group=[{}], SUM_DISTINCT_NOT_MATCH=[SUM(DISTINCT $2) FILTER
$3], COUNT_DISTINCT_MATCH=[COUNT(DISTINCT $4) FILTER $5],
COUNT_STAR_MATCH=[COUNT() FILTER $6], agg#3=[$SUM0($0) FILTER $1])
+ LogicalProject(SAL=[$5], $f1=[=($7, 10)], EMPNO=[$0], $f3=[=($7, 20)],
DEPTNO=[$7], $f5=[=($2, 'CLERK')], $f6=[=($7, 40)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(SUM0_MATCH=[$3], SUM_DISTINCT_NOT_MATCH=[$0],
COUNT_DISTINCT_MATCH=[$1], COUNT_STAR_MATCH=[$2])
+ LogicalAggregate(group=[{}], SUM_DISTINCT_NOT_MATCH=[SUM(DISTINCT $0) FILTER
$1], COUNT_DISTINCT_MATCH=[COUNT(DISTINCT $2)], COUNT_STAR_MATCH=[COUNT($3)],
agg#3=[$SUM0($4)])
+ LogicalProject(EMPNO=[$0], $f3=[=($7, 20)], $f7=[CASE(=($2, 'CLERK'), $7,
null:INTEGER)], $f8=[CASE(=($7, 40), 0, null:INTEGER)], $f9=[CASE(=($7, 10),
$5, null:INTEGER)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
diff --git a/core/src/test/resources/sql/planner.iq
b/core/src/test/resources/sql/planner.iq
index 47231cae12..844962d0ad 100644
--- a/core/src/test/resources/sql/planner.iq
+++ b/core/src/test/resources/sql/planner.iq
@@ -319,4 +319,36 @@ EnumerableHashJoin(condition=[AND(=($0, $6), OR(AND(>($1,
11), <=($7, 32)), AND(
!plan
!set planner-rules original
+# [CALCITE-7086] Implement a rule that performs the inverse operation of AggregateCaseToFilterRule.
+# Refer to RelToSqlConverterTest.testAggregateFilterToCase(). The following two SQL statements
+# are the true filtered aggregate and the CASE-style aggregate produced by AggregateFilterToCaseRule.
+!use scott
+select
+ sum(sal) filter(where deptno = 10) as sum_match,
+ count(distinct deptno) filter(where job = 'CLERK') as count_distinct_match,
+ count(*) filter(where deptno = 40) as count_star_match
+ from emp;
++-----------+----------------------+------------------+
+| SUM_MATCH | COUNT_DISTINCT_MATCH | COUNT_STAR_MATCH |
++-----------+----------------------+------------------+
+| 8750.00 | 3 | 0 |
++-----------+----------------------+------------------+
+(1 row)
+
+!ok
+
+SELECT
+SUM(CASE WHEN CAST("DEPTNO" AS INTEGER) = 10 THEN "SAL" ELSE NULL END) AS
"SUM_MATCH",
+COUNT(DISTINCT CASE WHEN "JOB" = 'CLERK' THEN "DEPTNO" ELSE NULL END) AS
"COUNT_DISTINCT_MATCH",
+COUNT(CASE WHEN CAST("DEPTNO" AS INTEGER) = 40 THEN 0 ELSE NULL END) AS
"COUNT_STAR_MATCH"
+FROM "scott"."EMP";
++-----------+----------------------+------------------+
+| SUM_MATCH | COUNT_DISTINCT_MATCH | COUNT_STAR_MATCH |
++-----------+----------------------+------------------+
+| 8750.00 | 3 | 0 |
++-----------+----------------------+------------------+
+(1 row)
+
+!ok
+
# End planner.iq