This is an automated email from the ASF dual-hosted git repository.

zabetak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 7654dd8216 [CALCITE-7362] Add rule to transform WHERE clauses into 
filtered aggregates
7654dd8216 is described below

commit 7654dd82162508231cb3ed44e60fed356e931403
Author: Stamatis Zampetakis <[email protected]>
AuthorDate: Thu Feb 19 11:30:01 2026 +0100

    [CALCITE-7362] Add rule to transform WHERE clauses into filtered aggregates
---
 .../AggregateFilterToFilteredAggregateRule.java    | 105 +++++++++++++
 .../org/apache/calcite/rel/rules/CoreRules.java    |   5 +
 ...AggregateFilterToFilteredAggregateRuleTest.java | 120 ++++++++++++++
 .../AggregateFilterToFilteredAggregateRuleTest.xml | 173 +++++++++++++++++++++
 4 files changed, 403 insertions(+)

diff --git 
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToFilteredAggregateRule.java
 
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToFilteredAggregateRule.java
new file mode 100644
index 0000000000..dd82f69762
--- /dev/null
+++ 
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateFilterToFilteredAggregateRule.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.rules;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Filter;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.tools.RelBuilder;
+
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Rule that converts an aggregate on top of a filter into a filtered aggregate.
+ *
+ * <p>Before
+ * <pre><code>
+ *   SELECT SUM(salary)
+ *   FROM Emp
+ *   WHERE deptno = 10
+ *  </code></pre>
+ *
+ * <p>After
+ * <pre><code>
+ *   SELECT SUM(salary) FILTER (WHERE deptno = 10)
+ *   FROM Emp
+ *  </code></pre>
+ *
+ * <p>The rule fires only when the aggregate has no grouping keys (a grand
+ * total) and every aggregate function allows a {@code FILTER} clause. If an
+ * aggregate call already has a filter, its new filter is the conjunction of
+ * that filter and the {@code WHERE} condition. A nullable condition is wrapped
+ * in {@code IS TRUE}, because an aggregate filter argument must be
+ * {@code BOOLEAN NOT NULL}; rows for which the condition is {@code UNKNOWN}
+ * are excluded in both the original and the rewritten form.
+ *
+ * <p>The transformation is particularly useful in view-based rewriting.
+ * The removal of the {@code Filter} operators lifts some restrictions when using
+ * the {@link org.apache.calcite.rel.rules.materialize.MaterializedViewRules}.
+ *
+ * <p>Filtered aggregates can be transformed to other equivalent forms via other
+ * transformation rules (e.g., {@link AggregateFilterToCaseRule}).
+ */
+@Value.Enclosing public class AggregateFilterToFilteredAggregateRule
+    extends RelRule<AggregateFilterToFilteredAggregateRule.Config> {
+
+  private AggregateFilterToFilteredAggregateRule(Config config) {
+    super(config);
+  }
+
+  @Override public void onMatch(RelOptRuleCall call) {
+    final Aggregate aggregate = call.rel(0);
+    final Filter filter = call.rel(1);
+    if (!aggregate.getGroupSet().isEmpty()) {
+      // At the moment we only support the transformation for grand totals,
+      // i.e., aggregates with no grouping keys.
+      return;
+    }
+    // The rewrite is all-or-nothing: every call must accept a FILTER clause.
+    // Check this before touching the builder.
+    for (AggregateCall aggCall : aggregate.getAggCallList()) {
+      if (!aggCall.getAggregation().allowsFilter()) {
+        return;
+      }
+    }
+    final RelBuilder builder = call.builder();
+    builder.push(filter.getInput());
+    // Start from all input fields so that the argument and filter indexes of
+    // the existing aggregate calls stay valid; unused fields are pruned when
+    // the aggregate is built.
+    final List<RexNode> projects = new ArrayList<>(builder.fields());
+    final List<AggregateCall> newAggCalls = new ArrayList<>();
+    for (AggregateCall aggCall : aggregate.getAggCallList()) {
+      RexNode condition = filter.getCondition();
+      // If the aggregate call has its own filter, combine it with the filter condition.
+      if (aggCall.hasFilter()) {
+        condition = builder.and(condition, builder.field(aggCall.filterArg));
+      }
+      // A Filter condition may be a nullable BOOLEAN, but an aggregate filter
+      // argument must be BOOLEAN NOT NULL. WHERE drops rows whose condition is
+      // UNKNOWN, so IS TRUE keeps exactly the same rows.
+      if (condition.getType().isNullable()) {
+        condition = builder.call(SqlStdOperatorTable.IS_TRUE, condition);
+      }
+      // Reuse an identical projected expression if one already exists.
+      int pos = projects.indexOf(condition);
+      if (pos < 0) {
+        pos = projects.size();
+        projects.add(condition);
+      }
+      newAggCalls.add(aggCall.withFilter(pos));
+    }
+    builder.project(projects);
+    builder.aggregate(builder.groupKey(), newAggCalls);
+    call.transformTo(builder.build());
+  }
+
+  /** Rule configuration. */
+  @Value.Immutable public interface Config extends RelRule.Config {
+    Config DEFAULT = ImmutableAggregateFilterToFilteredAggregateRule.Config.of()
+        .withOperandSupplier(a ->
+            a.operand(Aggregate.class).oneInput(f ->
+                f.operand(Filter.class).anyInputs()));
+
+    @Override default AggregateFilterToFilteredAggregateRule toRule() {
+      return new AggregateFilterToFilteredAggregateRule(this);
+    }
+  }
+}
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java 
b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
index 21d5c971d6..444c89ccdd 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
@@ -959,6 +959,11 @@ private CoreRules() {}
   public static final AggregateFilterToCaseRule AGGREGATE_FILTER_TO_CASE =
       AggregateFilterToCaseRule.Config.DEFAULT.toRule();
 
+  /** Rule that converts an aggregate on top of a filter into a filtered
+   * aggregate; it applies only to aggregates with no grouping keys. */
+  public static final AggregateFilterToFilteredAggregateRule
+      AGGREGATE_FILTER_TO_FILTERED_AGGREGATE =
+      AggregateFilterToFilteredAggregateRule.Config.DEFAULT.toRule();
+
   /** Rule that remove duplicate {@link Sort} keys. */
   public static final SortRemoveDuplicateKeysRule SORT_REMOVE_DUPLICATE_KEYS =
       SortRemoveDuplicateKeysRule.Config.DEFAULT.toRule();
diff --git 
a/core/src/test/java/org/apache/calcite/test/AggregateFilterToFilteredAggregateRuleTest.java
 
b/core/src/test/java/org/apache/calcite/test/AggregateFilterToFilteredAggregateRuleTest.java
new file mode 100644
index 0000000000..9e6cd3ea33
--- /dev/null
+++ 
b/core/src/test/java/org/apache/calcite/test/AggregateFilterToFilteredAggregateRuleTest.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.test;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.hep.HepProgram;
+import org.apache.calcite.rel.rules.AggregateFilterToFilteredAggregateRule;
+import org.apache.calcite.rel.rules.CoreRules;
+
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static 
org.apache.calcite.rel.rules.CoreRules.AGGREGATE_FILTER_TO_FILTERED_AGGREGATE;
+import static org.apache.calcite.rel.rules.CoreRules.AGGREGATE_PROJECT_MERGE;
+import static 
org.apache.calcite.rel.rules.CoreRules.PROJECT_FILTER_TRANSPOSE_WHOLE_PROJECT_EXPRESSIONS;
+
+/**
+ * Tests for {@link AggregateFilterToFilteredAggregateRule}.
+ *
+ * <p>Related issue:
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-7362">[CALCITE-7362]
+ * Add rule to transform WHERE clauses into filtered aggregates</a>.
+ *
+ * <p>Expected plans are kept in the reference file
+ * {@code AggregateFilterToFilteredAggregateRuleTest.xml}.
+ */
+class AggregateFilterToFilteredAggregateRuleTest {
+
+  /** Creates a fixture bound to this class's diff repository. */
+  private static RelOptFixture fixture() {
+    final DiffRepository repository =
+        DiffRepository.lookup(AggregateFilterToFilteredAggregateRuleTest.class);
+    return RelOptFixture.DEFAULT.withDiffRepos(repository);
+  }
+
+  /** Creates a fixture for the given SQL query. */
+  private static RelOptFixture sql(String query) {
+    return fixture().sql(query);
+  }
+
+  /** Merges the aggregate with its input project, applies the rule to
+   * {@code query}, and compares the resulting plan with the expected one. */
+  private static void checkRewritten(String query) {
+    sql(query)
+        .withPreRule(AGGREGATE_PROJECT_MERGE)
+        .withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE)
+        .check();
+  }
+
+  /** Merges the aggregate with its input project, applies the rule to
+   * {@code query}, and checks that the plan is left unchanged. */
+  private static void checkNotRewritten(String query) {
+    sql(query)
+        .withPreRule(AGGREGATE_PROJECT_MERGE)
+        .withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE)
+        .checkUnchanged();
+  }
+
+  @Test void testSingleColumnAggregate() {
+    checkRewritten("select sum(sal) from emp where deptno = 10");
+  }
+
+  @Test void testSingleStarAggregate() {
+    checkRewritten("select count(*) from emp where deptno = 10");
+  }
+
+  @Test void testMultiAggregates() {
+    checkRewritten("select sum(sal), min(sal), max(sal), count(*)"
+        + " from emp where deptno = 10");
+  }
+
+  /** An aggregate call with its own FILTER ends up filtered by the
+   * conjunction of that filter and the WHERE condition. */
+  @Test void testSingleColumnFilteredAggregate() {
+    // Pushing the project below the filter exposes the Aggregate-on-Filter
+    // shape that the rule matches.
+    final List<RelOptRule> preRules = new ArrayList<>();
+    preRules.add(AGGREGATE_PROJECT_MERGE);
+    preRules.add(PROJECT_FILTER_TRANSPOSE_WHOLE_PROJECT_EXPRESSIONS);
+    final HepProgram preProgram =
+        HepProgram.builder().addRuleCollection(preRules).build();
+    sql("select sum(sal) filter (where ename = 'Bob') from emp where deptno = 10")
+        .withPre(preProgram)
+        .withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE, CoreRules.PROJECT_MERGE)
+        .check();
+  }
+
+  /** SINGLE_VALUE does not support FILTER, so the plan must stay unchanged. */
+  @Test void testAggregateNoSupportingFilter() {
+    checkNotRewritten("select single_value(sal) from emp where deptno = 10");
+  }
+
+  /** Aggregates with grouping keys are not rewritten. */
+  @Test void testSingleColumnAggregateWithGroupBy() {
+    checkNotRewritten("select sum(sal) from emp where deptno = 10 group by job");
+  }
+
+  @Test void testSingleColumnAggregateWithGroupingSets() {
+    checkNotRewritten("select sum(sal) from emp where deptno = 10"
+        + " group by grouping sets ((job), (ename))");
+  }
+
+  /** {@code GROUP BY ()} is a grand total, so the rule applies. */
+  @Test void testSingleColumnAggregateWithEmptyGroupBy() {
+    checkRewritten("select sum(sal) from emp where deptno = 10 group by ()");
+  }
+
+  @Test void testSingleColumnAggregateWithEmptyGroupingSets() {
+    checkRewritten("select sum(sal) from emp where deptno = 10"
+        + " group by grouping sets (())");
+  }
+
+  @AfterAll static void checkActualAndReferenceFiles() {
+    fixture().diffRepos.checkActualAndReferenceFiles();
+  }
+}
diff --git 
a/core/src/test/resources/org/apache/calcite/test/AggregateFilterToFilteredAggregateRuleTest.xml
 
b/core/src/test/resources/org/apache/calcite/test/AggregateFilterToFilteredAggregateRuleTest.xml
new file mode 100644
index 0000000000..b31112c3f3
--- /dev/null
+++ 
b/core/src/test/resources/org/apache/calcite/test/AggregateFilterToFilteredAggregateRuleTest.xml
@@ -0,0 +1,173 @@
+<?xml version="1.0" ?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to you under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~ http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+<Root>
+  <TestCase name="testAggregateNoSupportingFilter">
+    <Resource name="sql">
+      <![CDATA[select single_value(sal) from emp where deptno = 10]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SINGLE_VALUE($5)])
+  LogicalFilter(condition=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testMultiAggregates">
+    <Resource name="sql">
+      <![CDATA[select sum(sal), min(sal), max(sal), count(*) from emp where deptno = 10]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($5)], EXPR$1=[MIN($5)], EXPR$2=[MAX($5)], EXPR$3=[COUNT()])
+  LogicalFilter(condition=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0) FILTER $1], EXPR$1=[MIN($0) FILTER $1], EXPR$2=[MAX($0) FILTER $1], EXPR$3=[COUNT() FILTER $1])
+  LogicalProject(SAL=[$5], $f9=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSingleColumnAggregate">
+    <Resource name="sql">
+      <![CDATA[select sum(sal) from emp where deptno = 10]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($5)])
+  LogicalFilter(condition=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0) FILTER $1])
+  LogicalProject(SAL=[$5], $f9=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSingleColumnAggregateWithEmptyGroupBy">
+    <Resource name="sql">
+      <![CDATA[select sum(sal) from emp where deptno = 10 group by ()]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($5)])
+  LogicalFilter(condition=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0) FILTER $1])
+  LogicalProject(SAL=[$5], $f9=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSingleColumnAggregateWithEmptyGroupingSets">
+    <Resource name="sql">
+      <![CDATA[select sum(sal) from emp where deptno = 10 group by grouping sets (())]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($5)])
+  LogicalFilter(condition=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0) FILTER $1])
+  LogicalProject(SAL=[$5], $f9=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSingleColumnAggregateWithGroupBy">
+    <Resource name="sql">
+      <![CDATA[select sum(sal) from emp where deptno = 10 group by job]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(EXPR$0=[$1])
+  LogicalAggregate(group=[{2}], EXPR$0=[SUM($5)])
+    LogicalFilter(condition=[=($7, 10)])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSingleColumnAggregateWithGroupingSets">
+    <Resource name="sql">
+      <![CDATA[select sum(sal) from emp where deptno = 10 group by grouping sets ((job), (ename))]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(EXPR$0=[$2])
+  LogicalProject(JOB=[$1], ENAME=[$0], EXPR$0=[$2])
+    LogicalAggregate(group=[{1, 2}], groups=[[{1}, {2}]], EXPR$0=[SUM($5)])
+      LogicalFilter(condition=[=($7, 10)])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSingleColumnFilteredAggregate">
+    <Resource name="sql">
+      <![CDATA[select sum(sal) filter (where ename = 'Bob') from emp where deptno = 10]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0) FILTER $2])
+  LogicalFilter(condition=[=($1, 10)])
+    LogicalProject(SAL=[$5], DEPTNO=[$7], $f2=[=($1, 'Bob')])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0) FILTER $1])
+  LogicalProject(SAL=[$5], $f3=[AND(=($7, 10), =($1, 'Bob'))])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testSingleStarAggregate">
+    <Resource name="sql">
+      <![CDATA[select count(*) from emp where deptno = 10]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
+  LogicalFilter(condition=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT() FILTER $0])
+  LogicalProject($f9=[=($7, 10)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+</Root>

Reply via email to