This is an automated email from the ASF dual-hosted git repository. jcamacho pushed a commit to branch CALCITE-2912 in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 1e81f956a4736b70c258b5594850236a4100e2a6 Author: Jesus Camacho Rodriguez <[email protected]> AuthorDate: Tue Mar 12 15:53:52 2019 -0700 [CALCITE-2912] --- .../rules/AggregateProjectPullUpConstantsRule.java | 35 ++++-- .../calcite/rel/rules/AggregateReduceRule.java | 127 +++++++++++++++++++++ .../apache/calcite/test/MaterializationTest.java | 16 +++ .../org/apache/calcite/test/RelOptRulesTest.java | 10 +- .../org/apache/calcite/test/RelOptRulesTest.xml | 31 ++--- 5 files changed, 192 insertions(+), 27 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectPullUpConstantsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectPullUpConstantsRule.java index a12e6d1..9a65488 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectPullUpConstantsRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectPullUpConstantsRule.java @@ -31,11 +31,13 @@ import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.Pair; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; import java.util.NavigableMap; @@ -100,7 +102,7 @@ public class AggregateProjectPullUpConstantsRule extends RelOptRule { assert !aggregate.indicator : "predicate ensured no grouping sets"; final int groupCount = aggregate.getGroupCount(); - if (groupCount == 1) { + if (groupCount < 1) { // No room for optimization since we cannot convert from non-empty // GROUP BY list to the empty one. return; @@ -127,14 +129,7 @@ public class AggregateProjectPullUpConstantsRule extends RelOptRule { return; } - if (groupCount == map.size()) { - // At least a single item in group by is required. - // Otherwise "GROUP BY 1, 2" might be altered to "GROUP BY ()". - // Removing of the first element is not optimal here, - // however it will allow us to use fast path below (just trim - // groupCount). - map.remove(map.navigableKeySet().first()); - } + final boolean empty = groupCount == map.size(); ImmutableBitSet newGroupSet = aggregate.getGroupSet(); for (int key : map.keySet()) { @@ -154,7 +149,25 @@ public class AggregateProjectPullUpConstantsRule extends RelOptRule { aggCall.adaptTo(input, aggCall.getArgList(), aggCall.filterArg, groupCount, newGroupCount)); } - relBuilder.aggregate(relBuilder.groupKey(newGroupSet), newAggCalls); + + // Create aggregate operator. + if (empty) { + // If empty, create an additional count(*) field + Aggregate tmpAggregate = (Aggregate) relBuilder + .aggregate(relBuilder.groupKey(), relBuilder.countStar(null)) + .build(); + newAggCalls.add(tmpAggregate.getAggCallList().get(0)); + // Reset stack and create new aggregate call + relBuilder.push(tmpAggregate.getInput()); + relBuilder.aggregate(relBuilder.groupKey(), newAggCalls); + // Add a filter on the new count(*) != 0 + relBuilder.filter( + rexBuilder.makeCall(SqlStdOperatorTable.NOT_EQUALS, + relBuilder.field(relBuilder.peek().getRowType().getFieldCount() - 1), + rexBuilder.makeBigintLiteral(BigDecimal.ZERO))); + } else { + relBuilder.aggregate(relBuilder.groupKey(newGroupSet), newAggCalls); + } // Create a projection back again. List<Pair<RexNode, String>> projects = new ArrayList<>(); @@ -186,6 +199,8 @@ public class AggregateProjectPullUpConstantsRule extends RelOptRule { projects.add(Pair.of(expr, field.getName())); } relBuilder.project(Pair.left(projects), Pair.right(projects)); // inverse + // Create top Project fixing nullability of fields + relBuilder.convert(aggregate.getRowType(), false); call.transformTo(relBuilder.build()); } diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceRule.java new file mode 100644 index 0000000..4d237de --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceRule.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.tools.RelBuilder; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Planner rule that reduces aggregate functions in + * {@link org.apache.calcite.rel.core.Aggregate}s to simpler forms. + * + * <p>Rewrites: + * <ul> + * + * <li>COUNT(x) → COUNT(*) if x is not nullable + * </ul> + * + * It also removes duplicate aggregate calls. + */ +public class AggregateReduceRule extends RelOptRule { + + /** The singleton. */ + public static final AggregateReduceRule INSTANCE = + new AggregateReduceRule(); + + /** Private constructor. */ + private AggregateReduceRule() { + super(operand(LogicalAggregate.class, any()), + RelFactories.LOGICAL_BUILDER, null); + } + + @Override public void onMatch(RelOptRuleCall call) { + final RelBuilder relBuilder = call.builder(); + final Aggregate aggRel = call.rel(0); + final RexBuilder rexBuilder = aggRel.getCluster().getRexBuilder(); + + // We try to rewrite COUNT(x) into COUNT(*) if x is not nullable. + // We remove duplicate aggregate calls as well. + boolean rewrite = false; + boolean identity = true; + final Map<AggregateCall, Integer> mapping = new HashMap<>(); + final List<Integer> indexes = new ArrayList<>(); + final List<AggregateCall> aggCalls = aggRel.getAggCallList(); + final List<AggregateCall> newAggCalls = new ArrayList<>(aggCalls.size()); + int nextIdx = aggRel.getGroupCount() + aggRel.getIndicatorCount(); + for (int i = 0; i < aggCalls.size(); i++) { + AggregateCall aggCall = aggCalls.get(i); + if (aggCall.getAggregation().getKind() == SqlKind.COUNT && !aggCall.isDistinct()) { + final List<Integer> args = aggCall.getArgList(); + final List<Integer> nullableArgs = new ArrayList<>(args.size()); + for (int arg : args) { + if (aggRel.getInput().getRowType().getFieldList().get(arg).getType().isNullable()) { + nullableArgs.add(arg); + } + } + if (nullableArgs.size() != args.size()) { + aggCall = aggCall.copy(nullableArgs, aggCall.filterArg, aggCall.collation); + rewrite = true; + } + } + Integer idx = mapping.get(aggCall); + if (idx == null) { + newAggCalls.add(aggCall); + idx = nextIdx++; + mapping.put(aggCall, idx); + } else { + rewrite = true; + identity = false; + } + indexes.add(idx); + } + + if (rewrite) { + // We trigger the transform + final Aggregate newAggregate = aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(), + aggRel.indicator, aggRel.getGroupSet(), aggRel.getGroupSets(), + newAggCalls); + if (identity) { + call.transformTo(newAggregate); + } else { + final int offset = aggRel.getGroupCount() + aggRel.getIndicatorCount(); + final List<RexNode> projList = new ArrayList<>(); + for (int i = 0; i < offset; ++i) { + projList.add( + rexBuilder.makeInputRef( + aggRel.getRowType().getFieldList().get(i).getType(), i)); + } + for (int i = offset; i < aggRel.getRowType().getFieldCount(); ++i) { + projList.add( + rexBuilder.makeInputRef( + aggRel.getRowType().getFieldList().get(i).getType(), indexes.get(i - offset))); + } + call.transformTo(relBuilder.push(newAggregate).project(projList).build()); + } + } + } + +} + +// End AggregateReduceRule.java diff --git a/core/src/test/java/org/apache/calcite/test/MaterializationTest.java b/core/src/test/java/org/apache/calcite/test/MaterializationTest.java index 3126db2..5192bbe 100644 --- a/core/src/test/java/org/apache/calcite/test/MaterializationTest.java +++ b/core/src/test/java/org/apache/calcite/test/MaterializationTest.java @@ -2193,6 +2193,22 @@ public class MaterializationTest { + " EnumerableTableScan(table=[[hr, m0]]")); } +// @Test public void testAggregateMaterializationWithConstantFilter() { +// checkMaterialize( +// "select \"deptno\", \"name\", count(*) as c\n" +// + "from \"emps\" group by \"deptno\", \"name\"", +// "select \"name\", count(*) as c\n" +// + "from \"emps\" where \"name\" = 'a_name' group by \"name\""); +// } +// +// @Test public void testAggregateMaterializationWithConstantFilter2() { +// checkMaterialize( +// "select \"deptno\", \"name\", \"salary\", count(*) as c\n" +// + "from \"emps\" group by \"deptno\", \"name\", \"salary\"", +// "select \"deptno\", \"name\", count(*) as c\n" +// + "from \"emps\" where \"name\" = 'a_name' group by \"deptno\", \"name\""); +// } + @Test public void testMaterializationSubstitution() { String q = "select *\n" + "from (select * from \"emps\" where \"empid\" < 300)\n" diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index ef2b705..a8ec4fd 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -56,6 +56,7 @@ import org.apache.calcite.rel.rules.AggregateJoinTransposeRule; import org.apache.calcite.rel.rules.AggregateProjectMergeRule; import org.apache.calcite.rel.rules.AggregateProjectPullUpConstantsRule; import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; +import org.apache.calcite.rel.rules.AggregateReduceRule; import org.apache.calcite.rel.rules.AggregateUnionAggregateRule; import org.apache.calcite.rel.rules.AggregateUnionTransposeRule; import org.apache.calcite.rel.rules.AggregateValuesRule; @@ -3822,17 +3823,20 @@ public class RelOptRulesTest extends RelOptTestBase { checkPlanning(new HepPlanner(program), sql); } - /** Tests {@link AggregateProjectPullUpConstantsRule} where reduction is not - * possible because "deptno" is the only key. */ + /** Tests {@link AggregateProjectPullUpConstantsRule} where all columns can be + * reduced. */ @Test public void testAggregateConstantKeyRule2() { final HepProgram program = new HepProgramBuilder() .addRuleInstance(AggregateProjectPullUpConstantsRule.INSTANCE2) + .addRuleInstance(AggregateReduceRule.INSTANCE) + .addRuleInstance(FilterProjectTransposeRule.INSTANCE) + .addRuleInstance(ProjectMergeRule.INSTANCE) .build(); final String sql = "select count(*) as c\n" + "from sales.emp\n" + "where deptno = 10\n" + "group by deptno"; - checkPlanUnchanged(new HepPlanner(program), sql); + checkPlanning(new HepPlanner(program), sql); } /** Tests {@link AggregateProjectPullUpConstantsRule} where both keys are diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index c94842c..fb03df2 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -4047,10 +4047,11 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)]) </Resource> <Resource name="planAfter"> <![CDATA[ -LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1]) - LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)]) - LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], MGR=[$3]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], EXPR$2=[$0]) + LogicalFilter(condition=[<>($1, 0)]) + LogicalAggregate(group=[{}], EXPR$2=[MAX($2)], agg#1=[COUNT()]) + LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], MGR=[$3]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> </Resource> </TestCase> @@ -4069,10 +4070,11 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)]) </Resource> <Resource name="planAfter"> <![CDATA[ -LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1]) - LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)]) - LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], FIVE=[5]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], EXPR$2=[CAST($0):INTEGER NOT NULL]) + LogicalFilter(condition=[<>($1, 0)]) + LogicalAggregate(group=[{}], EXPR$2=[MAX($2)], agg#1=[COUNT()]) + LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], FIVE=[5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> </Resource> </TestCase> @@ -4091,10 +4093,11 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)]) </Resource> <Resource name="planAfter"> <![CDATA[ -LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1]) - LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)]) - LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], $f2=[5]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], EXPR$2=[CAST($0):INTEGER NOT NULL]) + LogicalFilter(condition=[<>($1, 0)]) + LogicalAggregate(group=[{}], EXPR$2=[MAX($2)], agg#1=[COUNT()]) + LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], $f2=[5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> </Resource> </TestCase> @@ -7977,8 +7980,8 @@ LogicalProject(C=[$1]) </Resource> <Resource name="planAfter"> <![CDATA[ -LogicalProject(C=[$1]) - LogicalAggregate(group=[{0}], C=[COUNT()]) +LogicalFilter(condition=[<>($0, 0)]) + LogicalAggregate(group=[{}], C=[COUNT()]) LogicalProject(DEPTNO=[$7]) LogicalFilter(condition=[=($7, 10)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]])
