This is an automated email from the ASF dual-hosted git repository. jhyde pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 687b7d8cc8c63259138b43745a9e28ef64483839 Author: Julian Hyde <[email protected]> AuthorDate: Mon Jul 1 12:57:14 2019 -0700 [CALCITE-3144] Add rule, AggregateCaseToFilterRule, that converts "SUM(CASE WHEN b THEN x END)" to "SUM(x) FILTER (WHERE b)" Copied from Apache Druid's CaseFilteredAggregatorRule. Some aggregate functions (e.g. SINGLE_VALUE, GROUPING, GROUP_ID) do not allow filter, so we skip them in the rule. --- .../java/org/apache/calcite/plan/RelOptRules.java | 2 + .../org/apache/calcite/rel/core/AggregateCall.java | 2 + .../rel/rules/AggregateCaseToFilterRule.java | 268 +++++++++++++++++++++ .../calcite/sql/fun/SqlSingleValueAggFunction.java | 4 + .../org/apache/calcite/test/RelOptRulesTest.java | 22 +- .../org/apache/calcite/test/RelOptTestBase.java | 9 +- .../org/apache/calcite/test/RelOptRulesTest.xml | 30 +++ core/src/test/resources/sql/agg.iq | 32 +++ 8 files changed, 366 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptRules.java b/core/src/main/java/org/apache/calcite/plan/RelOptRules.java index dcb2f7d..cff4b6c 100644 --- a/core/src/main/java/org/apache/calcite/plan/RelOptRules.java +++ b/core/src/main/java/org/apache/calcite/plan/RelOptRules.java @@ -22,6 +22,7 @@ import org.apache.calcite.interpreter.NoneToBindableConverterRule; import org.apache.calcite.linq4j.function.Experimental; import org.apache.calcite.plan.volcano.AbstractConverter; import org.apache.calcite.rel.rules.AbstractMaterializedViewRule; +import org.apache.calcite.rel.rules.AggregateCaseToFilterRule; import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule; import org.apache.calcite.rel.rules.AggregateJoinTransposeRule; import org.apache.calcite.rel.rules.AggregateMergeRule; @@ -121,6 +122,7 @@ public class RelOptRules { FilterJoinRule.FILTER_ON_JOIN, JoinPushExpressionsRule.INSTANCE, AggregateExpandDistinctAggregatesRule.INSTANCE, + AggregateCaseToFilterRule.INSTANCE, AggregateReduceFunctionsRule.INSTANCE, 
FilterAggregateTransposeRule.INSTANCE, ProjectWindowTransposeRule.INSTANCE, diff --git a/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java b/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java index e95f46a..bb3c109 100644 --- a/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java +++ b/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java @@ -26,6 +26,7 @@ import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.util.mapping.Mapping; import org.apache.calcite.util.mapping.Mappings; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.List; @@ -100,6 +101,7 @@ public class AggregateCall { this.distinct = distinct; this.approximate = approximate; this.ignoreNulls = ignoreNulls; + Preconditions.checkArgument(filterArg < 0 || aggFunction.allowsFilter()); } //~ Methods ---------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateCaseToFilterRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateCaseToFilterRule.java new file mode 100644 index 0000000..77966ec --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateCaseToFilterRule.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlPostfixOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; + +import com.google.common.collect.ImmutableList; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; + +/** + * Rule that converts CASE-style filtered aggregates into true filtered + * aggregates.
+ * + * <p>For example, + * + * <blockquote> + * <code>SELECT SUM(CASE WHEN gender = 'F' THEN salary END)<br> + * FROM Emp</code> + * </blockquote> + * + * <p>becomes + * + * <blockquote> + * <code>SELECT SUM(salary) FILTER (WHERE gender = 'F')<br> + * FROM Emp</code> + * </blockquote> + */ +public class AggregateCaseToFilterRule extends RelOptRule { + public static final AggregateCaseToFilterRule INSTANCE = + new AggregateCaseToFilterRule(RelFactories.LOGICAL_BUILDER, null); + + /** Creates an AggregateCaseToFilterRule. */ + protected AggregateCaseToFilterRule(RelBuilderFactory relBuilderFactory, + String description) { + super(operand(Aggregate.class, operand(Project.class, any())), + relBuilderFactory, description); + } + + @Override public boolean matches(final RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + final Project project = call.rel(1); + + for (AggregateCall aggregateCall : aggregate.getAggCallList()) { + final int singleArg = soleArgument(aggregateCall); + if (singleArg >= 0 + && isThreeArgCase(project.getProjects().get(singleArg))) { + return true; + } + } + + return false; + } + + @Override public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + final Project project = call.rel(1); + final List<AggregateCall> newCalls = + new ArrayList<>(aggregate.getAggCallList().size()); + final List<RexNode> newProjects = new ArrayList<>(project.getProjects()); + + for (AggregateCall aggregateCall : aggregate.getAggCallList()) { + AggregateCall newCall = + transform(aggregateCall, project, newProjects); + + // Keep calls that cannot be rewritten. A rewritten call may have a + // different type (e.g. case B yields a BIGINT NOT NULL COUNT); + // "convert" below casts the result back to the original row type. 
+ newCalls.add(newCall == null ? aggregateCall : newCall);
+ } + + if (newCalls.equals(aggregate.getAggCallList())) { + return; + } + + final RelBuilder relBuilder = call.builder() + .push(project.getInput()) + .project(newProjects); + + final RelBuilder.GroupKey groupKey = + relBuilder.groupKey(aggregate.getGroupSet(), + aggregate.getGroupSets()); + + relBuilder.aggregate(groupKey, newCalls) + .convert(aggregate.getRowType(), false); + + call.transformTo(relBuilder.build()); + call.getPlanner().setImportance(aggregate, 0.0); + } + + /** Tries to rewrite a call whose sole argument is a three-operand CASE; + * appends any new expressions to {@code newProjects}, or returns null. */ + private @Nullable AggregateCall transform(AggregateCall aggregateCall, + Project project, List<RexNode> newProjects) { + final int singleArg = soleArgument(aggregateCall); + if (singleArg < 0) { + return null; + } + + final RexNode rexNode = project.getProjects().get(singleArg); + if (!isThreeArgCase(rexNode)) { + return null; + } + + final RelOptCluster cluster = project.getCluster(); + final RexBuilder rexBuilder = cluster.getRexBuilder(); + final RexCall caseCall = (RexCall) rexNode; + + // If one arg is null and the other is not, reverse them and set "flip", + // which negates the filter. "CASE WHEN b THEN NULL ELSE x END" yields x + // when b is FALSE or UNKNOWN, so the negated filter is "b IS NOT TRUE". + final boolean flip = RexLiteral.isNullLiteral(caseCall.operands.get(1)) + && !RexLiteral.isNullLiteral(caseCall.operands.get(2)); + final RexNode arg1 = caseCall.operands.get(flip ? 2 : 1); + final RexNode arg2 = caseCall.operands.get(flip ? 1 : 2); + + // Operand 1: Filter + final SqlPostfixOperator op = + flip ? 
SqlStdOperatorTable.IS_NOT_TRUE : SqlStdOperatorTable.IS_TRUE; + final RexNode filterFromCase = + rexBuilder.makeCall(op, caseCall.operands.get(0)); + + // Combine the CASE filter with an honest-to-goodness SQL FILTER, if the + // latter is present. + final RexNode filter; + if (aggregateCall.filterArg >= 0) { + filter = rexBuilder.makeCall(SqlStdOperatorTable.AND, + project.getProjects().get(aggregateCall.filterArg), filterFromCase); + } else { + filter = filterFromCase; + } + + final SqlKind kind = aggregateCall.getAggregation().getKind(); + if (aggregateCall.isDistinct()) { + // Just one style supported: + // COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END) + // => + // COUNT(DISTINCT y) FILTER(WHERE x = 'foo') + + if (kind == SqlKind.COUNT + && RexLiteral.isNullLiteral(arg2)) { + newProjects.add(arg1); + newProjects.add(filter); + return AggregateCall.create(SqlStdOperatorTable.COUNT, true, false, + false, ImmutableList.of(newProjects.size() - 2), + newProjects.size() - 1, RelCollations.EMPTY, + aggregateCall.getType(), aggregateCall.getName()); + } + return null; + } + + // Four styles supported: + // + // A1: AGG(CASE WHEN x = 'foo' THEN cnt END) + // => operands (x = 'foo', cnt, null) + // A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) + // => operands (x = 'foo', cnt, 0); must be SUM + // B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) + // => operands (x = 'foo', 1, 0); must be SUM + // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) + // => operands (x = 'foo', 'dummy', null) + // + // NOTE(review): A2 and B can change results. If a group has rows but none + // passes the filter, SUM(CASE ... ELSE 0 END) is 0 but SUM(cnt) FILTER is + // NULL; with no input rows, SUM is NULL but COUNT is 0. Consider SUM0. 
+ + if (kind == SqlKind.COUNT // Case C + && arg1.isA(SqlKind.LITERAL) + && !RexLiteral.isNullLiteral(arg1) + && RexLiteral.isNullLiteral(arg2)) { + newProjects.add(filter); + return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, + false, ImmutableList.of(), newProjects.size() - 1, + RelCollations.EMPTY, aggregateCall.getType(), + aggregateCall.getName()); + } else if (kind == SqlKind.SUM // Case B + && isIntLiteral(arg1, BigDecimal.ONE)
+ && isIntLiteral(arg2, BigDecimal.ZERO)) { + + newProjects.add(filter); + final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); + final RelDataType dataType = + typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BIGINT), false); + return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, + false, ImmutableList.of(), newProjects.size() - 1, + RelCollations.EMPTY, dataType, aggregateCall.getName()); + } else if ((RexLiteral.isNullLiteral(arg2) // Case A1 + && aggregateCall.getAggregation().allowsFilter()) + || (kind == SqlKind.SUM // Case A2 + && isIntLiteral(arg2, BigDecimal.ZERO))) { + newProjects.add(arg1); + newProjects.add(filter); + return AggregateCall.create(aggregateCall.getAggregation(), false, + false, false, ImmutableList.of(newProjects.size() - 2), + newProjects.size() - 1, RelCollations.EMPTY, + aggregateCall.getType(), aggregateCall.getName()); + } else { + return null; + } + } + + /** Returns the argument, if an aggregate call has a single argument, + * otherwise -1. */ + private static int soleArgument(AggregateCall aggregateCall) { + return aggregateCall.getArgList().size() == 1 + ?
aggregateCall.getArgList().get(0) + : -1; + } + + private static boolean isThreeArgCase(final RexNode rexNode) { + return rexNode.getKind() == SqlKind.CASE + && ((RexCall) rexNode).operands.size() == 3; + } + + /** Returns whether a node is an integer literal whose value equals + * {@code value}; a NULL literal never matches. */ + private static boolean isIntLiteral(final RexNode rexNode, + BigDecimal value) { + if (!(rexNode instanceof RexLiteral) + || !SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName())) { + return false; + } + // Compare exact values: RexLiteral.intValue would throw on a NULL literal + // and would truncate a BIGINT such as 4294967297 to 1. + final Object literalValue = ((RexLiteral) rexNode).getValue(); + return literalValue instanceof BigDecimal + && value.compareTo((BigDecimal) literalValue) == 0; + } +} + +// End AggregateCaseToFilterRule.java diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlSingleValueAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlSingleValueAggFunction.java index 6f4518e..8eb52b7 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlSingleValueAggFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlSingleValueAggFunction.java @@ -59,6 +59,10 @@ public class SqlSingleValueAggFunction extends SqlAggFunction { //~ Methods ---------------------------------------------------------------- + @Override public boolean allowsFilter() { + return false; + } + @SuppressWarnings("deprecation") public List<RelDataType> getParameterTypes(RelDataTypeFactory typeFactory) { return ImmutableList.of(type); diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 756d223..8e97702 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -61,6 +61,7 @@ import org.apache.calcite.rel.metadata.CachingRelMetadataProvider; import 
org.apache.calcite.rel.metadata.ChainedRelMetadataProvider; import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider; import org.apache.calcite.rel.metadata.RelMetadataProvider; +import org.apache.calcite.rel.rules.AggregateCaseToFilterRule; import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule; import org.apache.calcite.rel.rules.AggregateExtractProjectRule; import
org.apache.calcite.rel.rules.AggregateFilterTransposeRule; @@ -3297,6 +3298,25 @@ public class RelOptRulesTest extends RelOptTestBase { sql(sql).withPre(pre).withRule(rule).checkUnchanged(); } + /** Test case for + * <a href="https://issues.apache.org/jira/browse/CALCITE-3144">[CALCITE-3144] + * Add rule, AggregateCaseToFilterRule, that converts "SUM(CASE WHEN b THEN x + * END)" to "SUM(x) FILTER (WHERE b)"</a>. */ + @Test public void testAggregateCaseToFilter() { + final String sql = "select\n" + + " sum(sal) as sum_sal,\n" + + " count(distinct case\n" + + " when job = 'CLERK'\n" + + " then deptno else null end) as count_distinct_clerk,\n" + + " sum(case when deptno = 10 then sal end) as sum_sal_d10,\n" + + " sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,\n" + + " sum(case when deptno = 30 then 1 else 0 end) as count_d30,\n" + + " count(case when deptno = 40 then 'x' end) as count_d40,\n" + + " count(case when deptno = 20 then 1 end) as count_d20\n" + + "from emp"; + sql(sql).withRule(AggregateCaseToFilterRule.INSTANCE).check(); + } + @Test public void testPullAggregateThroughUnion() { HepProgram program = new HepProgramBuilder() .addRuleInstance(AggregateUnionAggregateRule.INSTANCE) @@ -5256,7 +5276,7 @@ public class RelOptRulesTest extends RelOptTestBase { /** Test case for * <a href="https://issues.apache.org/jira/browse/CALCITE-3121">[CALCITE-3121] - * VolcanoPlanner hangs due to subquery with dynamic star</a>. */ + * VolcanoPlanner hangs due to sub-query with dynamic star</a>. 
*/ @Test public void testSubQueryWithDynamicStarHang() { String sql = "select n.n_regionkey from (select * from " + "(select * from sales.customer) t) n where n.n_nationkey >1"; diff --git a/core/src/test/java/org/apache/calcite/test/RelOptTestBase.java b/core/src/test/java/org/apache/calcite/test/RelOptTestBase.java index 507edd2..bf263c9 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptTestBase.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptTestBase.java @@ -262,8 +262,13 @@ abstract class RelOptTestBase extends SqlToRelTestBase { transforms); } - public Sql withRule(RelOptRule rule) { - return with(HepProgram.builder().addRuleInstance(rule).build()); + /** Builds a {@link HepProgram} of the given rules, in order, and passes it to {@code with}. */ + public Sql withRule(RelOptRule... rules) { + final HepProgramBuilder builder = HepProgram.builder(); + for (RelOptRule rule : rules) { + builder.addRuleInstance(rule); + } + return with(builder.build()); } /** Adds a transform that will be applied to {@link #tester} diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index d2d83e0..288dfce 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -16,6 +16,36 @@ See the License for the specific language governing permissions and limitations under the License.
--> <Root> + <TestCase name="testAggregateCaseToFilter"> + <Resource name="sql"> + <![CDATA[select + sum(sal) as sum_sal, + count(distinct case + when job = 'CLERK' + then deptno else null end) as count_distinct_clerk, + sum(case when deptno = 10 then sal end) as sum_sal_d10, + sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20, + sum(case when deptno = 30 then 1 else 0 end) as count_d30, + count(case when deptno = 40 then 'x' end) as count_d40, + count(case when deptno = 20 then 1 end) as count_d20 +from emp]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $1)], SUM_SAL_D10=[SUM($2)], SUM_SAL_D20=[SUM($3)], COUNT_D30=[SUM($4)], COUNT_D40=[COUNT($5)], COUNT_D20=[COUNT($6)]) + LogicalProject(SAL=[$5], $f1=[CASE(=($2, 'CLERK'), $7, null:INTEGER)], $f2=[CASE(=($7, 10), $5, null:INTEGER)], $f3=[CASE(=($7, 20), $5, 0)], $f4=[CASE(=($7, 30), 1, 0)], $f5=[CASE(=($7, 40), 'x', null:CHAR(1))], $f6=[CASE(=($7, 20), 1, null:INTEGER)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +LogicalProject(SUM_SAL=[$0], COUNT_DISTINCT_CLERK=[$1], SUM_SAL_D10=[$2], SUM_SAL_D20=[$3], COUNT_D30=[CAST($4):INTEGER], COUNT_D40=[$5], COUNT_D20=[$6]) + LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $7) FILTER $8], SUM_SAL_D10=[SUM($9) FILTER $10], SUM_SAL_D20=[SUM($11) FILTER $12], COUNT_D30=[COUNT() FILTER $13], COUNT_D40=[COUNT() FILTER $14], COUNT_D20=[COUNT() FILTER $15]) + LogicalProject(SAL=[$5], $f1=[CASE(=($2, 'CLERK'), $7, null:INTEGER)], $f2=[CASE(=($7, 10), $5, null:INTEGER)], $f3=[CASE(=($7, 20), $5, 0)], $f4=[CASE(=($7, 30), 1, 0)], $f5=[CASE(=($7, 40), 'x', null:CHAR(1))], $f6=[CASE(=($7, 20), 1, null:INTEGER)], DEPTNO=[$7], $f8=[=($2, 'CLERK')], SAL0=[$5], $f10=[=($7, 10)], SAL1=[$5], $f12=[=($7, 20)], $f13=[=($7, 30)], $f14=[=($7, 40)], $f15=[=($7, 20)]) + 
LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> <TestCase name="testAggregateExtractProjectRule"> <Resource name="sql"> <![CDATA[select sum(sal) diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq index 550e36f..3f939f7 100644 --- a/core/src/test/resources/sql/agg.iq +++ b/core/src/test/resources/sql/agg.iq @@ -2548,6 +2548,38 @@ select deptno, bit_and(empno), bit_or(empno) from "scott".emp group by deptno; !ok +# Based on [DRUID-7593] Exact distinct-COUNT with complex expression (CASE, IN) throws +# NullPointerException +WITH wikipedia AS ( + SELECT empno AS delta, + CASE WHEN deptno = 10 THEN 'true' ELSE 'false' END AS isRobot, + ename AS "user" + FROM "scott".emp) +SELECT COUNT(DISTINCT + CASE WHEN (((CASE WHEN wikipedia.delta IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + THEN REPLACE('Yes', 'Yes', 'Yes') + ELSE REPLACE('No', 'No', 'No') END) = 'No')) + AND (wikipedia.isRobot = 'true') + THEN (wikipedia."user") + ELSE NULL END) + - (MAX(CASE WHEN (((CASE WHEN wikipedia.delta IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + THEN REPLACE('Yes', 'Yes', 'Yes') + ELSE REPLACE('No', 'No', 'No') END) = 'No')) + AND (wikipedia.isRobot = 'true') + THEN NULL + ELSE -9223372036854775807 END) + + 9223372036854775807 + 1) AS "wikipedia.count_distinct_filters_that_dont_work" +FROM wikipedia +LIMIT 500; ++-------------------------------------------------+ +| wikipedia.count_distinct_filters_that_dont_work | ++-------------------------------------------------+ +| 2 | ++-------------------------------------------------+ +(1 row) + +!ok + # [CALCITE-2266] JSON_OBJECTAGG, JSON_ARRAYAGG !use post
