This is an automated email from the ASF dual-hosted git repository. jhyde pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 5abedb1778f025078db567c5c386315a5625fc51 Author: Julian Hyde <[email protected]> AuthorDate: Tue Aug 4 21:47:30 2020 -0700 [CALCITE-4154] Add a rule, ProjectAggregateMergeRule, to merge a Project onto an Aggregate Rule also converts COALESCE(SUM(x), 0) to SUM0(x). --- .../java/org/apache/calcite/plan/RelOptRules.java | 1 + .../org/apache/calcite/rel/rules/CoreRules.java | 5 + .../rel/rules/ProjectAggregateMergeRule.java | 198 +++++++++++++++++++++ .../org/apache/calcite/test/RelOptRulesTest.java | 46 +++++ .../org/apache/calcite/test/RelOptRulesTest.xml | 89 +++++++++ 5 files changed, 339 insertions(+) diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptRules.java b/core/src/main/java/org/apache/calcite/plan/RelOptRules.java index 5e33487..b8e1164 100644 --- a/core/src/main/java/org/apache/calcite/plan/RelOptRules.java +++ b/core/src/main/java/org/apache/calcite/plan/RelOptRules.java @@ -123,6 +123,7 @@ public class RelOptRules { CoreRules.AGGREGATE_REMOVE, CoreRules.UNION_TO_DISTINCT, CoreRules.PROJECT_REMOVE, + CoreRules.PROJECT_AGGREGATE_MERGE, CoreRules.AGGREGATE_JOIN_TRANSPOSE, CoreRules.AGGREGATE_MERGE, CoreRules.AGGREGATE_PROJECT_MERGE, diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java index 2861cd3..9411f71 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java @@ -340,6 +340,11 @@ public class CoreRules { public static final UnionMergeRule MINUS_MERGE = UnionMergeRule.Config.MINUS.toRule(); + /** Rule that matches a {@link Project} on an {@link Aggregate}, projecting + * away unused aggregate calls and converting COALESCE(SUM(x), 0) to SUM0(x). */ + public static final ProjectAggregateMergeRule PROJECT_AGGREGATE_MERGE = + ProjectAggregateMergeRule.Config.DEFAULT.toRule(); + /** Rule that merges a {@link LogicalProject} and a {@link LogicalCalc}.
* * @see #FILTER_CALC_MERGE */ diff --git a/core/src/main/java/org/apache/calcite/rel/rules/ProjectAggregateMergeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/ProjectAggregateMergeRule.java new file mode 100644 index 0000000..33207f6 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/ProjectAggregateMergeRule.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexPermuteInputsShuttle; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.mapping.MappingType; +import org.apache.calcite.util.mapping.Mappings; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Planner rule that matches a {@link Project} on a {@link Aggregate} + * and projects away aggregate calls that are not used. + * + * <p>Also converts {@code NVL(SUM(x), 0)} to {@code SUM0(x)}. + * + * @see CoreRules#PROJECT_AGGREGATE_MERGE + */ +public class ProjectAggregateMergeRule + extends RelRule<ProjectAggregateMergeRule.Config> + implements TransformationRule { + + /** Creates a ProjectAggregateMergeRule. */ + protected ProjectAggregateMergeRule(Config config) { + super(config); + } + + @Override public void onMatch(RelOptRuleCall call) { + final Project project = call.rel(0); + final Aggregate aggregate = call.rel(1); + final RelOptCluster cluster = aggregate.getCluster(); + + // Do a quick check. If all aggregate calls are used, and there are no CASE + // expressions, there is nothing to do. 
+ final ImmutableBitSet bits = + RelOptUtil.InputFinder.bits(project.getProjects(), null); + if (bits.contains( + ImmutableBitSet.range(aggregate.getGroupCount(), + aggregate.getRowType().getFieldCount())) + && kindCount(project.getProjects(), SqlKind.CASE) == 0) { + return; + } + + // Replace 'COALESCE(SUM(x), 0)' with 'SUM0(x)' wherever it occurs. + // Add 'SUM0(x)' to the aggregate call list, if necessary. + final List<AggregateCall> aggCallList = + new ArrayList<>(aggregate.getAggCallList()); + final RexShuttle shuttle = new RexShuttle() { + @Override public RexNode visitCall(RexCall call) { + switch (call.getKind()) { + case CASE: + // Do we have "CASE(IS NOT NULL($0), CAST($0):INTEGER NOT NULL, 0)"? + final List<RexNode> operands = call.operands; + if (operands.size() == 3 + && operands.get(0).getKind() == SqlKind.IS_NOT_NULL + && ((RexCall) operands.get(0)).operands.get(0).getKind() + == SqlKind.INPUT_REF + && operands.get(1).getKind() == SqlKind.CAST + && ((RexCall) operands.get(1)).operands.get(0).getKind() + == SqlKind.INPUT_REF + && operands.get(2).getKind() == SqlKind.LITERAL) { + final RexCall isNotNull = (RexCall) operands.get(0); + final RexInputRef ref0 = (RexInputRef) isNotNull.operands.get(0); + final RexCall cast = (RexCall) operands.get(1); + final RexInputRef ref1 = (RexInputRef) cast.operands.get(0); + final RexLiteral literal = (RexLiteral) operands.get(2); + if (ref0.getIndex() == ref1.getIndex() + && literal.getValueAs(BigDecimal.class).equals(BigDecimal.ZERO)) { + final int aggCallIndex = + ref1.getIndex() - aggregate.getGroupCount(); + if (aggCallIndex >= 0) { + final AggregateCall aggCall = + aggregate.getAggCallList().get(aggCallIndex); + if (aggCall.getAggregation().getKind() == SqlKind.SUM) { + int j = + findSum0(cluster.getTypeFactory(), aggCall, aggCallList); + return cluster.getRexBuilder().makeInputRef(call.type, j); + } + } + } + } + } + return super.visitCall(call); + } + }; + final List<RexNode> projects2 = 
shuttle.visitList(project.getProjects()); + final ImmutableBitSet bits2 = + RelOptUtil.InputFinder.bits(projects2, null); + + // Build the mapping that we will apply to the project expressions. + final Mappings.TargetMapping mapping = + Mappings.create(MappingType.FUNCTION, + aggregate.getGroupCount() + aggCallList.size(), -1); + int j = 0; + for (int i = 0; i < mapping.getSourceCount(); i++) { + if (i < aggregate.getGroupCount()) { + // Field is a group key. All group keys are retained. + mapping.set(i, j++); + } else if (bits2.get(i)) { + // Field is an aggregate call. It is used. + mapping.set(i, j++); + } else { + // Field is an aggregate call. It is not used. Remove it. + aggCallList.remove(j - aggregate.getGroupCount()); + } + } + + final RelBuilder builder = call.builder(); + builder.push(aggregate.getInput()); + builder.aggregate( + builder.groupKey(aggregate.getGroupSet(), + (Iterable<ImmutableBitSet>) aggregate.groupSets), aggCallList); + builder.project( + RexPermuteInputsShuttle.of(mapping).visitList(projects2)); + call.transformTo(builder.build()); + } + + /** Given a call to SUM, finds a call to SUM0 with identical arguments, + * or creates one and adds it to the list. Returns the index. */ + private static int findSum0(RelDataTypeFactory typeFactory, AggregateCall sum, + List<AggregateCall> aggCallList) { + final AggregateCall sum0 = + AggregateCall.create(SqlStdOperatorTable.SUM0, sum.isDistinct(), + sum.isApproximate(), sum.ignoreNulls(), sum.getArgList(), + sum.filterArg, sum.collation, + typeFactory.createTypeWithNullability(sum.type, false), null); + final int i = aggCallList.indexOf(sum0); + if (i >= 0) { + return i; + } + aggCallList.add(sum0); + return aggCallList.size() - 1; + } + + /** Returns the number of calls of a given kind in a list of expressions. */ + private static int kindCount(Iterable<? 
extends RexNode> nodes, + final SqlKind kind) { + final AtomicInteger kindCount = new AtomicInteger(0); + new RexVisitorImpl<Void>(true) { + @Override public Void visitCall(RexCall call) { + if (call.getKind() == kind) { + kindCount.incrementAndGet(); + } + return super.visitCall(call); + } + }.visitEach(nodes); + return kindCount.get(); + } + + /** Rule configuration. */ + public interface Config extends RelRule.Config { + Config DEFAULT = EMPTY + .withOperandSupplier(b0 -> + b0.operand(Project.class) + .oneInput(b1 -> + b1.operand(Aggregate.class).anyInputs())) + .as(Config.class); + + @Override default ProjectAggregateMergeRule toRule() { + return new ProjectAggregateMergeRule(this); + } + } +} diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index c7a464a..e35d0dc 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -4740,6 +4740,52 @@ class RelOptRulesTest extends RelOptTestBase { .checkUnchanged(); } + /** Tests that ProjectAggregateMergeRule removes unused aggregate calls but + * not group keys. */ + @Test void testProjectAggregateMerge() { + final String sql = "select deptno + ss\n" + + "from (\n" + + " select job, deptno, min(sal) as ms, sum(sal) as ss\n" + + " from sales.emp\n" + + " group by job, deptno)"; + sql(sql).withRule(CoreRules.PROJECT_AGGREGATE_MERGE) + .check(); + } + + /** Tests that ProjectAggregateMergeRule does nothing when all aggregate calls + * are referenced. 
*/ + @Test void testProjectAggregateMergeNoOp() { + final String sql = "select deptno + ss + ms\n" + + "from (\n" + + " select job, deptno, min(sal) as ms, sum(sal) as ss\n" + + " from sales.emp\n" + + " group by job, deptno)"; + sql(sql).withRule(CoreRules.PROJECT_AGGREGATE_MERGE) + .checkUnchanged(); + } + + /** Tests that ProjectAggregateMergeRule converts {@code COALESCE(SUM(x), 0)} + * into {@code SUM0(x)}. */ + @Test void testProjectAggregateMergeSum0() { + final String sql = "select coalesce(sum_sal, 0) as ss0\n" + + "from (\n" + + " select sum(sal) as sum_sal\n" + + " from sales.emp)"; + sql(sql).withRule(CoreRules.PROJECT_AGGREGATE_MERGE) + .check(); + } + + /** As {@link #testProjectAggregateMergeSum0()} but there is another use of + * {@code SUM} that cannot be converted to {@code SUM0}. */ + @Test void testProjectAggregateMergeSum0AndSum() { + final String sql = "select sum_sal * 2, coalesce(sum_sal, 0) as ss0\n" + + "from (\n" + + " select sum(sal) as sum_sal\n" + + " from sales.emp)"; + sql(sql).withRule(CoreRules.PROJECT_AGGREGATE_MERGE) + .check(); + } + /** * Test case for AggregateMergeRule, should merge 2 aggregates * into a single aggregate. 
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 7b9f367..071d0f8 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -5202,6 +5202,95 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$ ]]> </Resource> </TestCase> + <TestCase name="testProjectAggregateMerge"> + <Resource name="sql"> + <![CDATA[select deptno + ss +from ( + select job, deptno, min(sal) as ms, sum(sal) as ss + from sales.emp + group by job, deptno)]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(EXPR$0=[+($1, $3)]) + LogicalAggregate(group=[{0, 1}], MS=[MIN($2)], SS=[SUM($2)]) + LogicalProject(JOB=[$2], DEPTNO=[$7], SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +LogicalProject($f0=[+($1, $2)]) + LogicalAggregate(group=[{0, 1}], SS=[SUM($2)]) + LogicalProject(JOB=[$2], DEPTNO=[$7], SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testProjectAggregateMergeNoOp"> + <Resource name="sql"> + <![CDATA[select deptno + ss + ms +from ( + select job, deptno, min(sal) as ms, sum(sal) as ss + from sales.emp + group by job, deptno)]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(EXPR$0=[+(+($1, $3), $2)]) + LogicalAggregate(group=[{0, 1}], MS=[MIN($2)], SS=[SUM($2)]) + LogicalProject(JOB=[$2], DEPTNO=[$7], SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testProjectAggregateMergeSum0"> + <Resource name="sql"> + <![CDATA[select coalesce(sum_sal, 0) as ss0 +from ( + select sum(sal) as sum_sal + from sales.emp)]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(SS0=[CASE(IS NOT NULL($0), 
CAST($0):INTEGER NOT NULL, 0)]) + LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)]) + LogicalProject(SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +LogicalAggregate(group=[{}], agg#0=[$SUM0($0)]) + LogicalProject(SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> + <TestCase name="testProjectAggregateMergeSum0AndSum"> + <Resource name="sql"> + <![CDATA[select sum_sal * 2, coalesce(sum_sal, 0) as ss0 +from ( + select sum(sal) as sum_sal + from sales.emp)]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(EXPR$0=[*($0, 2)], SS0=[CASE(IS NOT NULL($0), CAST($0):INTEGER NOT NULL, 0)]) + LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)]) + LogicalProject(SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +LogicalProject($f0=[*($0, 2)], $f1=[$1]) + LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], agg#1=[$SUM0($0)]) + LogicalProject(SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> <TestCase name="testRemoveSemiJoinWithFilter"> <Resource name="sql"> <