This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 5abedb1778f025078db567c5c386315a5625fc51
Author: Julian Hyde <[email protected]>
AuthorDate: Tue Aug 4 21:47:30 2020 -0700

    [CALCITE-4154] Add a rule, ProjectAggregateMergeRule, to merge a Project onto an Aggregate
    
    Rule also converts COALESCE(SUM(x), 0) to SUM0(x).
---
 .../java/org/apache/calcite/plan/RelOptRules.java  |   1 +
 .../org/apache/calcite/rel/rules/CoreRules.java    |   5 +
 .../rel/rules/ProjectAggregateMergeRule.java       | 198 +++++++++++++++++++++
 .../org/apache/calcite/test/RelOptRulesTest.java   |  46 +++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    |  89 +++++++++
 5 files changed, 339 insertions(+)

diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptRules.java 
b/core/src/main/java/org/apache/calcite/plan/RelOptRules.java
index 5e33487..b8e1164 100644
--- a/core/src/main/java/org/apache/calcite/plan/RelOptRules.java
+++ b/core/src/main/java/org/apache/calcite/plan/RelOptRules.java
@@ -123,6 +123,7 @@ public class RelOptRules {
       CoreRules.AGGREGATE_REMOVE,
       CoreRules.UNION_TO_DISTINCT,
       CoreRules.PROJECT_REMOVE,
+      CoreRules.PROJECT_AGGREGATE_MERGE,
       CoreRules.AGGREGATE_JOIN_TRANSPOSE,
       CoreRules.AGGREGATE_MERGE,
       CoreRules.AGGREGATE_PROJECT_MERGE,
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java 
b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
index 2861cd3..9411f71 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
@@ -340,6 +340,11 @@ public class CoreRules {
   public static final UnionMergeRule MINUS_MERGE =
       UnionMergeRule.Config.MINUS.toRule();
 
+  /** Rule that matches a {@link Project} on an {@link Aggregate}, projecting
+   * away unused aggregate calls and converting COALESCE(SUM(x), 0) to SUM0(x). */
+  public static final ProjectAggregateMergeRule PROJECT_AGGREGATE_MERGE =
+      ProjectAggregateMergeRule.Config.DEFAULT.toRule();
+
   /** Rule that merges a {@link LogicalProject} and a {@link LogicalCalc}.
    *
    * @see #FILTER_CALC_MERGE */
diff --git 
a/core/src/main/java/org/apache/calcite/rel/rules/ProjectAggregateMergeRule.java
 
b/core/src/main/java/org/apache/calcite/rel/rules/ProjectAggregateMergeRule.java
new file mode 100644
index 0000000..33207f6
--- /dev/null
+++ 
b/core/src/main/java/org/apache/calcite/rel/rules/ProjectAggregateMergeRule.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.rules;
+
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexPermuteInputsShuttle;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexVisitorImpl;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.mapping.MappingType;
+import org.apache.calcite.util.mapping.Mappings;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Planner rule that matches a {@link Project} on a {@link Aggregate}
+ * and projects away aggregate calls that are not used.
+ *
+ * <p>Also converts {@code NVL(SUM(x), 0)} to {@code SUM0(x)}.
+ *
+ * @see CoreRules#PROJECT_AGGREGATE_MERGE
+ */
+/**
+ * Planner rule that matches a {@link Project} on an {@link Aggregate},
+ * projects away aggregate calls that are not used, and converts
+ * {@code COALESCE(SUM(x), 0)} to {@code SUM0(x)}.
+ *
+ * <p>Group keys are always retained, even if the {@link Project} does not
+ * use them.
+ *
+ * @see CoreRules#PROJECT_AGGREGATE_MERGE
+ */
+public class ProjectAggregateMergeRule
+    extends RelRule<ProjectAggregateMergeRule.Config>
+    implements TransformationRule {
+
+  /** Creates a ProjectAggregateMergeRule. */
+  protected ProjectAggregateMergeRule(Config config) {
+    super(config);
+  }
+
+  @Override public void onMatch(RelOptRuleCall call) {
+    final Project project = call.rel(0);
+    final Aggregate aggregate = call.rel(1);
+    final RelOptCluster cluster = aggregate.getCluster();
+    final int groupCount = aggregate.getGroupCount();
+    // Output fields of the aggregate that are produced by aggregate calls
+    // (the group keys come first).
+    final ImmutableBitSet aggCallFields =
+        ImmutableBitSet.range(groupCount,
+            aggregate.getRowType().getFieldCount());
+
+    // Do a quick check. If all aggregate calls are used, and there are no CASE
+    // expressions, there is nothing to do.
+    final ImmutableBitSet bits =
+        RelOptUtil.InputFinder.bits(project.getProjects(), null);
+    if (bits.contains(aggCallFields)
+        && kindCount(project.getProjects(), SqlKind.CASE) == 0) {
+      return;
+    }
+
+    // Replace 'COALESCE(SUM(x), 0)' with 'SUM0(x)' wherever it occurs.
+    // Add 'SUM0(x)' to the aggregate call list, if necessary.
+    final List<AggregateCall> aggCallList =
+        new ArrayList<>(aggregate.getAggCallList());
+    final RexShuttle shuttle = new RexShuttle() {
+      @Override public RexNode visitCall(RexCall call) {
+        switch (call.getKind()) {
+        case CASE:
+          // Do we have "CASE(IS NOT NULL($0), CAST($0):INTEGER NOT NULL, 0)"?
+          final List<RexNode> operands = call.operands;
+          if (operands.size() == 3
+              && operands.get(0).getKind() == SqlKind.IS_NOT_NULL
+              && ((RexCall) operands.get(0)).operands.get(0).getKind()
+              == SqlKind.INPUT_REF
+              && operands.get(1).getKind() == SqlKind.CAST
+              && ((RexCall) operands.get(1)).operands.get(0).getKind()
+              == SqlKind.INPUT_REF
+              && operands.get(2).getKind() == SqlKind.LITERAL) {
+            final RexCall isNotNull = (RexCall) operands.get(0);
+            final RexInputRef ref0 = (RexInputRef) isNotNull.operands.get(0);
+            final RexCall cast = (RexCall) operands.get(1);
+            final RexInputRef ref1 = (RexInputRef) cast.operands.get(0);
+            final RexLiteral literal = (RexLiteral) operands.get(2);
+            final int aggCallIndex = ref1.getIndex() - groupCount;
+            if (ref0.getIndex() == ref1.getIndex()
+                && aggCallIndex >= 0
+                && isZero(literal)) {
+              final AggregateCall aggCall =
+                  aggregate.getAggCallList().get(aggCallIndex);
+              if (aggCall.getAggregation().getKind() == SqlKind.SUM) {
+                final int j =
+                    findSum0(cluster.getTypeFactory(), aggCall, aggCallList);
+                // 'j' indexes the aggregate call list, but the aggregate's
+                // output row starts with the group keys; offset past them.
+                final RexNode sum0Ref =
+                    cluster.getRexBuilder().makeInputRef(
+                        aggCallList.get(j).type, groupCount + j);
+                // Keep the type of the CASE expression; ensureType adds a
+                // CAST only if SUM0's type differs from it.
+                return cluster.getRexBuilder().ensureType(call.getType(),
+                    sum0Ref, true);
+              }
+            }
+          }
+          break;
+        default:
+          break;
+        }
+        return super.visitCall(call);
+      }
+    };
+    final List<RexNode> projects2 = shuttle.visitList(project.getProjects());
+    final ImmutableBitSet bits2 =
+        RelOptUtil.InputFinder.bits(projects2, null);
+
+    // If no CASE was rewritten and every aggregate call is still used, the
+    // rule would merely rebuild an equivalent plan; do nothing.
+    if (projects2.equals(project.getProjects())
+        && bits2.contains(aggCallFields)) {
+      return;
+    }
+
+    // Build the mapping that we will apply to the project expressions.
+    final Mappings.TargetMapping mapping =
+        Mappings.create(MappingType.FUNCTION,
+            groupCount + aggCallList.size(), -1);
+    int j = 0;
+    for (int i = 0; i < mapping.getSourceCount(); i++) {
+      if (i < groupCount) {
+        // Field is a group key. All group keys are retained.
+        mapping.set(i, j++);
+      } else if (bits2.get(i)) {
+        // Field is an aggregate call. It is used.
+        mapping.set(i, j++);
+      } else {
+        // Field is an aggregate call. It is not used. Remove it. Only the
+        // retained calls precede it in the list, so it is at 'j - groupCount'.
+        aggCallList.remove(j - groupCount);
+      }
+    }
+
+    final RelBuilder builder = call.builder();
+    builder.push(aggregate.getInput());
+    builder.aggregate(
+        builder.groupKey(aggregate.getGroupSet(),
+            (Iterable<ImmutableBitSet>) aggregate.groupSets), aggCallList);
+    builder.project(
+        RexPermuteInputsShuttle.of(mapping).visitList(projects2));
+    call.transformTo(builder.build());
+  }
+
+  /** Returns whether a literal is a numeric zero, such as {@code 0} or
+   * {@code 0.00}. Literals whose value is not a {@link BigDecimal} (for
+   * example NULL, character or boolean literals) yield false rather than
+   * causing a conversion failure. */
+  private static boolean isZero(RexLiteral literal) {
+    final Object value = literal.getValue();
+    return value instanceof BigDecimal
+        && ((BigDecimal) value).signum() == 0;
+  }
+
+  /** Given a call to SUM, finds a call to SUM0 with identical arguments,
+   * or creates one and adds it to the list. Returns its index within
+   * {@code aggCallList}; callers must add the group count to obtain the
+   * ordinal of the corresponding aggregate output field. */
+  private static int findSum0(RelDataTypeFactory typeFactory, AggregateCall sum,
+      List<AggregateCall> aggCallList) {
+    final AggregateCall sum0 =
+        AggregateCall.create(SqlStdOperatorTable.SUM0, sum.isDistinct(),
+            sum.isApproximate(), sum.ignoreNulls(), sum.getArgList(),
+            sum.filterArg, sum.collation,
+            typeFactory.createTypeWithNullability(sum.type, false), null);
+    final int i = aggCallList.indexOf(sum0);
+    if (i >= 0) {
+      return i;
+    }
+    aggCallList.add(sum0);
+    return aggCallList.size() - 1;
+  }
+
+  /** Returns the number of calls of a given kind in a list of expressions. */
+  private static int kindCount(Iterable<? extends RexNode> nodes,
+      final SqlKind kind) {
+    final AtomicInteger kindCount = new AtomicInteger(0);
+    new RexVisitorImpl<Void>(true) {
+      @Override public Void visitCall(RexCall call) {
+        if (call.getKind() == kind) {
+          kindCount.incrementAndGet();
+        }
+        return super.visitCall(call);
+      }
+    }.visitEach(nodes);
+    return kindCount.get();
+  }
+
+  /** Rule configuration. */
+  public interface Config extends RelRule.Config {
+    Config DEFAULT = EMPTY
+        .withOperandSupplier(b0 ->
+            b0.operand(Project.class)
+                .oneInput(b1 ->
+                    b1.operand(Aggregate.class).anyInputs()))
+        .as(Config.class);
+
+    @Override default ProjectAggregateMergeRule toRule() {
+      return new ProjectAggregateMergeRule(this);
+    }
+  }
+}
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java 
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index c7a464a..e35d0dc 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -4740,6 +4740,52 @@ class RelOptRulesTest extends RelOptTestBase {
         .checkUnchanged();
   }
 
+  /** Checks that ProjectAggregateMergeRule drops an aggregate call the outer
+   * query never reads (here {@code MIN(sal)}) while keeping every group key,
+   * even the unreferenced {@code job}. */
+  @Test void testProjectAggregateMerge() {
+    final String query = "select deptno + ss\n"
+        + "from (\n"
+        + "  select job, deptno, min(sal) as ms, sum(sal) as ss\n"
+        + "  from sales.emp\n"
+        + "  group by job, deptno)";
+    sql(query)
+        .withRule(CoreRules.PROJECT_AGGREGATE_MERGE)
+        .check();
+  }
+
+  /** Checks that ProjectAggregateMergeRule leaves the plan alone when the
+   * outer query reads every aggregate call ({@code ms} and {@code ss}). */
+  @Test void testProjectAggregateMergeNoOp() {
+    final String query = "select deptno + ss + ms\n"
+        + "from (\n"
+        + "  select job, deptno, min(sal) as ms, sum(sal) as ss\n"
+        + "  from sales.emp\n"
+        + "  group by job, deptno)";
+    sql(query)
+        .withRule(CoreRules.PROJECT_AGGREGATE_MERGE)
+        .checkUnchanged();
+  }
+
+  /** Checks that ProjectAggregateMergeRule rewrites
+   * {@code COALESCE(SUM(x), 0)} as {@code SUM0(x)}. */
+  @Test void testProjectAggregateMergeSum0() {
+    final String query = "select coalesce(sum_sal, 0) as ss0\n"
+        + "from (\n"
+        + "  select sum(sal) as sum_sal\n"
+        + "  from sales.emp)";
+    sql(query)
+        .withRule(CoreRules.PROJECT_AGGREGATE_MERGE)
+        .check();
+  }
+
+  /** Like {@link #testProjectAggregateMergeSum0()}, but the outer query also
+   * uses the plain {@code SUM} ({@code sum_sal * 2}), which must be kept
+   * alongside the new {@code SUM0}. */
+  @Test void testProjectAggregateMergeSum0AndSum() {
+    final String query = "select sum_sal * 2, coalesce(sum_sal, 0) as ss0\n"
+        + "from (\n"
+        + "  select sum(sal) as sum_sal\n"
+        + "  from sales.emp)";
+    sql(query)
+        .withRule(CoreRules.PROJECT_AGGREGATE_MERGE)
+        .check();
+  }
+
   /**
    * Test case for AggregateMergeRule, should merge 2 aggregates
    * into a single aggregate.
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 7b9f367..071d0f8 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -5202,6 +5202,95 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], 
MGR=[$3], HIREDATE=[$4], SAL=[$
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testProjectAggregateMerge">
+        <Resource name="sql">
+            <![CDATA[select deptno + ss
+from (
+  select job, deptno, min(sal) as ms, sum(sal) as ss
+  from sales.emp
+  group by job, deptno)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(EXPR$0=[+($1, $3)])
+  LogicalAggregate(group=[{0, 1}], MS=[MIN($2)], SS=[SUM($2)])
+    LogicalProject(JOB=[$2], DEPTNO=[$7], SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject($f0=[+($1, $2)])
+  LogicalAggregate(group=[{0, 1}], SS=[SUM($2)])
+    LogicalProject(JOB=[$2], DEPTNO=[$7], SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testProjectAggregateMergeNoOp">
+        <Resource name="sql">
+            <![CDATA[select deptno + ss + ms
+from (
+  select job, deptno, min(sal) as ms, sum(sal) as ss
+  from sales.emp
+  group by job, deptno)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(EXPR$0=[+(+($1, $3), $2)])
+  LogicalAggregate(group=[{0, 1}], MS=[MIN($2)], SS=[SUM($2)])
+    LogicalProject(JOB=[$2], DEPTNO=[$7], SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testProjectAggregateMergeSum0">
+        <Resource name="sql">
+            <![CDATA[select coalesce(sum_sal, 0) as ss0
+from (
+  select sum(sal) as sum_sal
+  from sales.emp)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(SS0=[CASE(IS NOT NULL($0), CAST($0):INTEGER NOT NULL, 0)])
+  LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)])
+    LogicalProject(SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{}], agg#0=[$SUM0($0)])
+  LogicalProject(SAL=[$5])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testProjectAggregateMergeSum0AndSum">
+        <Resource name="sql">
+            <![CDATA[select sum_sal * 2, coalesce(sum_sal, 0) as ss0
+from (
+  select sum(sal) as sum_sal
+  from sales.emp)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(EXPR$0=[*($0, 2)], SS0=[CASE(IS NOT NULL($0), CAST($0):INTEGER NOT NULL, 0)])
+  LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)])
+    LogicalProject(SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject($f0=[*($0, 2)], $f1=[$1])
+  LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], agg#1=[$SUM0($0)])
+    LogicalProject(SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
     <TestCase name="testRemoveSemiJoinWithFilter">
         <Resource name="sql">
             <![CDATA[select e.ename from emp e, dept d

Reply via email to