This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 8a6877d14eb8f15c9f5124e400f2f3e351f42c13 Author: yuzhao.cyz <[email protected]> AuthorDate: Tue Mar 17 21:31:45 2020 +0800 [FLINK-14338][table-planner][table-planner-blink] Update files due to CALCITE-3763 * CALCITE-3763 prunes useless fields of input project --- .../table/planner/calcite/FlinkRelBuilder.scala | 20 +++- .../WindowAggregateReduceFunctionsRule.scala | 15 ++- .../physical/batch/BatchExecHashAggRule.scala | 13 ++- .../physical/batch/BatchExecSortAggRule.scala | 17 ++- .../batch/BatchExecWindowAggregateRule.scala | 18 ++- .../batch/RemoveRedundantLocalHashAggRule.scala | 4 +- .../batch/RemoveRedundantLocalSortAggRule.scala | 4 +- .../table/planner/plan/batch/table/CalcTest.xml | 12 +- .../FlinkAggregateJoinTransposeRuleTest.xml | 8 +- .../logical/SimplifyJoinConditionRuleTest.xml | 2 +- .../batch/RemoveRedundantLocalHashAggRuleTest.xml | 26 +++++ .../batch/RemoveRedundantLocalRankRuleTest.xml | 6 +- .../batch/RemoveRedundantLocalSortAggRuleTest.xml | 30 +++++ .../planner/plan/stream/table/AggregateTest.xml | 6 +- .../plan/stream/table/TableAggregateTest.xml | 6 +- .../plan/stream/table/TwoStageAggregateTest.xml | 6 +- .../RemoveRedundantLocalHashAggRuleTest.scala | 13 +++ .../RemoveRedundantLocalSortAggRuleTest.scala | 11 ++ .../ExtendedAggregateExtractProjectRule.java | 8 +- .../flink/table/calcite/FlinkRelBuilder.scala | 21 +++- .../rules/common/LogicalWindowAggregateRule.scala | 126 ++++++++++++++++++++- .../WindowAggregateReduceFunctionsRule.scala | 14 ++- 22 files changed, 337 insertions(+), 49 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala index 47f1ab8..3b8eb0e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala @@ -28,6 +28,7 @@ import org.apache.flink.table.planner.plan.nodes.calcite.{LogicalTableAggregate, import org.apache.flink.table.planner.plan.utils.AggregateUtil import org.apache.flink.table.runtime.operators.rank.{RankRange, RankType} import org.apache.flink.table.sinks.TableSink + import org.apache.calcite.plan._ import org.apache.calcite.rel.RelCollation import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} @@ -40,6 +41,7 @@ import org.apache.calcite.util.{ImmutableBitSet, Util} import java.lang.Iterable import java.util import java.util.List +import java.util.function.UnaryOperator import scala.collection.JavaConversions._ @@ -132,7 +134,23 @@ class FlinkRelBuilder( namedProperties: List[PlannerNamedWindowProperty], aggCalls: Iterable[AggCall]): RelBuilder = { // build logical aggregate - val aggregate = super.aggregate(groupKey, aggCalls).build().asInstanceOf[LogicalAggregate] + + // Because of: + // [CALCITE-3763] RelBuilder.aggregate should prune unused fields from the input, + // if the input is a Project. + // + // the field cannot be pruned if it is referenced by other expressions + // of the window aggregation (e.g. the TUMBLE_START/END). + // To solve this, we configure the RelBuilder to disable this feature. 
+ val aggregate = transform( + new UnaryOperator[RelBuilder.Config] { + override def apply(t: RelBuilder.Config) + : RelBuilder.Config = t.withPruneInputOfAggregate(false) + }) + .push(build()) + .aggregate(groupKey, aggCalls) + .build() + .asInstanceOf[LogicalAggregate] // build logical window aggregate from it aggregate match { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/WindowAggregateReduceFunctionsRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/WindowAggregateReduceFunctionsRule.scala index 90dbdc3..f6d7f9e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/WindowAggregateReduceFunctionsRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/WindowAggregateReduceFunctionsRule.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.planner.plan.rules.logical import org.apache.flink.table.planner.plan.nodes.calcite.LogicalWindowAggregate +import org.apache.calcite.plan.Contexts import org.apache.calcite.plan.RelOptRule._ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{Aggregate, AggregateCall, RelFactories} @@ -39,7 +40,11 @@ import scala.collection.JavaConversions._ class WindowAggregateReduceFunctionsRule extends AggregateReduceFunctionsRule( operand(classOf[LogicalWindowAggregate], any()), - RelFactories.LOGICAL_BUILDER) { + RelBuilder.proto( + Contexts.of( + RelFactories.DEFAULT_STRUCT, + RelBuilder.Config.DEFAULT + .withPruneInputOfAggregate(false)))) { override def newAggregateRel( relBuilder: RelBuilder, @@ -47,6 +52,14 @@ class WindowAggregateReduceFunctionsRule newCalls: util.List[AggregateCall]): Unit = { // create a LogicalAggregate with simpler aggregation functions + + // Because of: + // [CALCITE-3763] RelBuilder.aggregate should prune unused fields 
from the input, + // if the input is a Project. + // + // the field cannot be pruned if it is referenced by other expressions + // of the window aggregation (e.g. the TUMBLE_START/END). + // To solve this, we configure the RelBuilder to disable this feature. super.newAggregateRel(relBuilder, oldAgg, newCalls) // pop LogicalAggregate from RelBuilder val newAgg = relBuilder.build().asInstanceOf[LogicalAggregate] diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala index 704d267..cc615f0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala @@ -118,6 +118,17 @@ class BatchExecHashAggRule } else { Seq(FlinkRelDistribution.SINGLETON) } + // Remove the global agg call filters because the + // filter is already done by local aggregation. 
+ val aggCallsWithoutFilter = aggCallsWithoutAuxGroupCalls.map { + aggCall => + if (aggCall.filterArg > 0) { + aggCall.copy(aggCall.getArgList, -1, aggCall.getCollation) + } else { + aggCall + } + } + val globalAggCallToAggFunction = aggCallsWithoutFilter.zip(aggFunctions) globalDistributions.foreach { globalDistribution => val requiredTraitSet = localHashAgg.getTraitSet.replace(globalDistribution) val newLocalHashAgg = RelOptRule.convert(localHashAgg, requiredTraitSet) @@ -131,7 +142,7 @@ class BatchExecHashAggRule inputRowType, globalGroupSet, globalAuxGroupSet, - aggCallToAggFunction, + globalAggCallToAggFunction, isMerge = true) call.transformTo(globalHashAgg) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala index 426e9df..bf18c59 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala @@ -74,10 +74,6 @@ class BatchExecSortAggRule val input: RelNode = call.rel(1) val inputRowType = input.getRowType - if (agg.indicator) { - throw new UnsupportedOperationException("Not support group sets aggregate now.") - } - val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions( @@ -124,6 +120,17 @@ class BatchExecSortAggRule } else { (Seq(FlinkRelDistribution.SINGLETON), RelCollations.EMPTY) } + // Remove the global agg call filters because the + // filter is already done by local aggregation. 
+ val aggCallsWithoutFilter = aggCallsWithoutAuxGroupCalls.map { + aggCall => + if (aggCall.filterArg > 0) { + aggCall.copy(aggCall.getArgList, -1, aggCall.getCollation) + } else { + aggCall + } + } + val globalAggCallToAggFunction = aggCallsWithoutFilter.zip(aggFunctions) globalDistributions.foreach { globalDistribution => val requiredTraitSet = localSortAgg.getTraitSet .replace(globalDistribution) @@ -140,7 +147,7 @@ class BatchExecSortAggRule newLocalInput.getRowType, globalGroupSet, globalAuxGroupSet, - aggCallToAggFunction, + globalAggCallToAggFunction, isMerge = true) call.transformTo(globalSortAgg) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala index 7920a1a..bf430cb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala @@ -34,12 +34,13 @@ import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDat import org.apache.flink.table.types.logical.{BigIntType, IntType, LogicalType} import org.apache.calcite.plan.RelOptRule._ -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.plan.{Contexts, RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.Aggregate.Group -import org.apache.calcite.rel.core.{Aggregate, AggregateCall} +import org.apache.calcite.rel.core.{Aggregate, AggregateCall, RelFactories} import org.apache.calcite.rel.{RelCollations, RelNode} import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.calcite.tools.RelBuilder 
import org.apache.commons.math3.util.ArithmeticUtils import scala.collection.JavaConversions._ @@ -70,6 +71,11 @@ class BatchExecWindowAggregateRule extends RelOptRule( operand(classOf[FlinkLogicalWindowAggregate], operand(classOf[RelNode], any)), + RelBuilder.proto( + Contexts.of( + RelFactories.DEFAULT_STRUCT, + RelBuilder.Config.DEFAULT + .withPruneInputOfAggregate(false))), "BatchExecWindowAggregateRule") with BatchExecAggRuleBase { @@ -156,6 +162,14 @@ class BatchExecWindowAggregateRule // TODO aggregate include projection now, so do not provide new trait will be safe val aggProvidedTraitSet = input.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) + + // Because of: + // [CALCITE-3763] RelBuilder.aggregate should prune unused fields from the input, + // if the input is a Project. + // + // the field cannot be pruned if it is referenced by other expressions + // of the window aggregation (e.g. the TUMBLE_START/END). + // To solve this, we configure the RelBuilder to disable this feature. 
    val inputTimeFieldIndex = AggregateUtil.timeFieldIndex( input.getRowType, call.builder(), window.timeAttribute) val inputTimeFieldType = agg.getInput.getRowType.getFieldList.get(inputTimeFieldIndex).getType diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala index c538dac..33a2bb2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala @@ -49,7 +49,9 @@ class RemoveRedundantLocalHashAggRule extends RelOptRule( inputOfLocalAgg.getRowType, localAgg.getGrouping, localAgg.getAuxGrouping, - globalAgg.getAggCallToAggFunction, + // Use the localAgg agg calls because the global agg call filters were removed, + // see BatchExecHashAggRule for details. 
+ localAgg.getAggCallToAggFunction, isMerge = false) call.transformTo(newGlobalAgg) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala index 615d082f..a0ff75e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala @@ -48,7 +48,9 @@ abstract class RemoveRedundantLocalSortAggRule( inputOfLocalAgg.getRowType, localAgg.getGrouping, localAgg.getAuxGrouping, - globalAgg.getAggCallToAggFunction, + // Use the localAgg agg calls because the global agg call filters were removed, + // see BatchExecSortAggRule for details. 
+ localAgg.getAggCallToAggFunction, isMerge = false) call.transformTo(newGlobalAgg) } diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/table/CalcTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/table/CalcTest.xml index d44892b..4c11764 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/table/CalcTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/table/CalcTest.xml @@ -168,8 +168,8 @@ Calc(select=[a]) <Resource name="planBefore"> <![CDATA[ LogicalProject(EXPR$0=[$1]) -+- LogicalAggregate(group=[{4}], EXPR$0=[SUM($0)]) - +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], k=[org$apache$flink$table$planner$plan$batch$table$CalcTest$MyHashCode$$1945176195778b1bff1a30c41ce16445($2)]) ++- LogicalAggregate(group=[{1}], EXPR$0=[SUM($0)]) + +- LogicalProject(a=[$0], k=[org$apache$flink$table$planner$plan$batch$table$CalcTest$MyHashCode$$1945176195778b1bff1a30c41ce16445($2)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]]) ]]> </Resource> @@ -179,7 +179,7 @@ Calc(select=[EXPR$0]) +- HashAggregate(isMerge=[true], groupBy=[k], select=[k, Final_SUM(sum$0) AS EXPR$0]) +- Exchange(distribution=[hash[k]]) +- LocalHashAggregate(groupBy=[k], select=[k, Partial_SUM(a) AS sum$0]) - +- Calc(select=[a, b, c, d, MyHashCode$(c) AS k]) + +- Calc(select=[a, MyHashCode$(c) AS k]) +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d]) ]]> </Resource> @@ -188,8 +188,8 @@ Calc(select=[EXPR$0]) <Resource name="planBefore"> <![CDATA[ LogicalProject(EXPR$0=[$1]) -+- LogicalAggregate(group=[{4}], EXPR$0=[SUM($0)]) - +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], k=[UPPER($2)]) ++- LogicalAggregate(group=[{1}], 
EXPR$0=[SUM($0)]) + +- LogicalProject(a=[$0], k=[UPPER($2)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]]) ]]> </Resource> @@ -199,7 +199,7 @@ Calc(select=[EXPR$0]) +- HashAggregate(isMerge=[true], groupBy=[k], select=[k, Final_SUM(sum$0) AS EXPR$0]) +- Exchange(distribution=[hash[k]]) +- LocalHashAggregate(groupBy=[k], select=[k, Partial_SUM(a) AS sum$0]) - +- Calc(select=[a, b, c, d, UPPER(c) AS k]) + +- Calc(select=[a, UPPER(c) AS k]) +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d]) ]]> </Resource> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml index 0d58e8b..35d417a 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateJoinTransposeRuleTest.xml @@ -216,8 +216,8 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)]) </Resource> <Resource name="planAfter"> <![CDATA[ -LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)]) -+- LogicalProject(a=[$0], $f1=[$1], a0=[$2], $f10=[$3], $f4=[*($1, $3)]) +LogicalAggregate(group=[{}], EXPR$0=[$SUM0($0)]) ++- LogicalProject($f4=[*($1, $3)]) +- LogicalJoin(condition=[=($0, $2)], joinType=[inner]) :- LogicalProject(a=[$0], $f1=[CASE(IS NOT NULL($0), 1:BIGINT, 0:BIGINT)]) : +- LogicalAggregate(group=[{0}]) @@ -244,8 +244,8 @@ LogicalAggregate(group=[{}], EXPR$0=[SUM($0)]) </Resource> <Resource name="planAfter"> <![CDATA[ -LogicalAggregate(group=[{}], EXPR$0=[SUM($3)]) -+- LogicalProject(a=[$0], a0=[$1], 
$f1=[$2], $f3=[CAST(*($0, $2)):INTEGER]) +LogicalAggregate(group=[{}], EXPR$0=[SUM($0)]) ++- LogicalProject($f3=[CAST(*($0, $2)):INTEGER]) +- LogicalJoin(condition=[=($0, $1)], joinType=[inner]) :- LogicalAggregate(group=[{0}]) : +- LogicalTableScan(table=[[default_catalog, default_database, T, source: [TestTableSource(a, b, c)]]]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyJoinConditionRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyJoinConditionRuleTest.xml index 6c65f74..182519d 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyJoinConditionRuleTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyJoinConditionRuleTest.xml @@ -42,7 +42,7 @@ LogicalProject(a=[$0]) +- LogicalJoin(condition=[AND(=($0, $3), =($1, $4))], joinType=[left]) :- LogicalTableScan(table=[[default_catalog, default_database, MyTable1, source: [TestTableSource(a, b, c)]]]) +- LogicalAggregate(group=[{0, 1}], EXPR$0=[COUNT()]) - +- LogicalProject(a=[$3], b=[$4], $f0=[0]) + +- LogicalProject(a=[$3], b=[$4]) +- LogicalJoin(condition=[AND(=($0, $3), OR(<($0, 2), =($4, 5)))], joinType=[inner]) :- LogicalTableScan(table=[[default_catalog, default_database, MyTable2, source: [TestTableSource(d, e, f)]]]) +- LogicalAggregate(group=[{0, 1}]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.xml index 8c16336..93897ab 100644 --- 
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.xml @@ -81,4 +81,30 @@ HashAggregate(isMerge=[false], groupBy=[a], select=[a, SUM(b) AS EXPR$1]) ]]> </Resource> </TestCase> + <TestCase name="testUsingLocalAggCallFilters"> + <Resource name="sql"> + <![CDATA[SELECT d, MAX(e), MAX(e) FILTER (WHERE a < 10), COUNT(DISTINCT c), +COUNT(DISTINCT c) FILTER (WHERE a > 5), COUNT(DISTINCT b) FILTER (WHERE b > 3) +FROM z GROUP BY d]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalAggregate(group=[{0}], EXPR$1=[MAX($1)], EXPR$2=[MAX($1) FILTER $2], EXPR$3=[COUNT(DISTINCT $3)], EXPR$4=[COUNT(DISTINCT $3) FILTER $4], EXPR$5=[COUNT(DISTINCT $5) FILTER $6]) ++- LogicalProject(d=[$3], e=[$4], $f2=[IS TRUE(<($0, 10))], c=[$2], $f4=[IS TRUE(>($0, 5))], b=[$1], $f6=[IS TRUE(>($1, 3))]) + +- LogicalTableScan(table=[[default_catalog, default_database, z, source: [TestTableSource(a, b, c, d, e)]]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +HashAggregate(isMerge=[false], groupBy=[d], select=[d, MIN(EXPR$1) FILTER $g_15 AS EXPR$1, MIN(EXPR$2) FILTER $g_15 AS EXPR$2, COUNT(c) FILTER $g_7 AS EXPR$3, COUNT(c) FILTER $g_3 AS EXPR$4, COUNT(b) FILTER $g_12 AS EXPR$5]) ++- Calc(select=[d, c, b, EXPR$1, EXPR$2, AND(=(CASE(=($e, 3:BIGINT), 3:BIGINT, =($e, 7:BIGINT), 7:BIGINT, =($e, 12:BIGINT), 12:BIGINT, 15:BIGINT), 3), $f4) AS $g_3, =(CASE(=($e, 3:BIGINT), 3:BIGINT, =($e, 7:BIGINT), 7:BIGINT, =($e, 12:BIGINT), 12:BIGINT, 15:BIGINT), 7) AS $g_7, AND(=(CASE(=($e, 3:BIGINT), 3:BIGINT, =($e, 7:BIGINT), 7:BIGINT, =($e, 12:BIGINT), 12:BIGINT, 15:BIGINT), 12), $f6) AS $g_12, =(CASE(=($e, 3:BIGINT), 3:BIGINT, =($e, 7:BIGINT), 7:BIGINT, =($e, 12:BIGINT), 12:BIGIN [...] 
+ +- HashAggregate(isMerge=[true], groupBy=[d, c, $f4, b, $f6, $e], select=[d, c, $f4, b, $f6, $e, Final_MAX(max$0) AS EXPR$1, Final_MAX(max$1) AS EXPR$2]) + +- Exchange(distribution=[hash[d]]) + +- LocalHashAggregate(groupBy=[d, c, $f4, b, $f6, $e], select=[d, c, $f4, b, $f6, $e, Partial_MAX(e) AS max$0, Partial_MAX(e) FILTER $f2 AS max$1]) + +- Expand(projects=[d, e, $f2, c, $f4, b, $f6, $e], projects=[{d, e, $f2, c, $f4, null AS b, null AS $f6, 3 AS $e}, {d, e, $f2, c, null AS $f4, null AS b, null AS $f6, 7 AS $e}, {d, e, $f2, null AS c, null AS $f4, b, $f6, 12 AS $e}, {d, e, $f2, null AS c, null AS $f4, null AS b, null AS $f6, 15 AS $e}]) + +- Calc(select=[d, e, IS TRUE(<(a, 10)) AS $f2, c, IS TRUE(>(a, 5)) AS $f4, b, IS TRUE(>(b, 3)) AS $f6]) + +- TableSourceScan(table=[[default_catalog, default_database, z, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e]) +]]> + </Resource> + </TestCase> </Root> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.xml index 4e7b5c5..df38440 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.xml @@ -100,9 +100,9 @@ LogicalProject(a=[$0], b=[$1], rk=[$2], rk1=[$3]) </Resource> <Resource name="planAfter"> <![CDATA[ -Rank(rankType=[RANK], rankRange=[rankStart=1, rankEnd=5], partitionBy=[a], orderBy=[b ASC], global=[true], select=[a, b, w0$o0, w0$o0]) -+- Calc(select=[a, b, w0$o0]) - +- Rank(rankType=[RANK], rankRange=[rankStart=1, rankEnd=5], partitionBy=[a], orderBy=[b ASC], global=[true], 
select=[a, b, c, w0$o0]) +Rank(rankType=[RANK], rankRange=[rankStart=1, rankEnd=5], partitionBy=[a], orderBy=[b ASC], global=[true], select=[a, b, $2, w0$o0]) ++- Calc(select=[a, b, $2]) + +- Rank(rankType=[RANK], rankRange=[rankStart=1, rankEnd=5], partitionBy=[a], orderBy=[b ASC], global=[true], select=[a, b, c, $2]) +- Sort(orderBy=[a ASC, b ASC]) +- Exchange(distribution=[hash[a]]) +- Rank(rankType=[RANK], rankRange=[rankStart=1, rankEnd=5], partitionBy=[a], orderBy=[b ASC], global=[false], select=[a, b, c]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.xml index 9a5c7f7..dc55eb4 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.xml @@ -83,4 +83,34 @@ Calc(select=[EXPR$0]) ]]> </Resource> </TestCase> + <TestCase name="testUsingLocalAggCallFilters"> + <Resource name="sql"> + <![CDATA[SELECT d, MAX(e), MAX(e) FILTER (WHERE a < 10), COUNT(DISTINCT c), +COUNT(DISTINCT c) FILTER (WHERE a > 5), COUNT(DISTINCT b) FILTER (WHERE b > 3) +FROM z GROUP BY d]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalAggregate(group=[{0}], EXPR$1=[MAX($1)], EXPR$2=[MAX($1) FILTER $2], EXPR$3=[COUNT(DISTINCT $3)], EXPR$4=[COUNT(DISTINCT $3) FILTER $4], EXPR$5=[COUNT(DISTINCT $5) FILTER $6]) ++- LogicalProject(d=[$3], e=[$4], $f2=[IS TRUE(<($0, 10))], c=[$2], $f4=[IS TRUE(>($0, 5))], b=[$1], $f6=[IS TRUE(>($1, 3))]) + +- LogicalTableScan(table=[[default_catalog, default_database, z, source: [TestTableSource(a, b, c, d, e)]]]) +]]> + 
</Resource> + <Resource name="planAfter"> + <![CDATA[ +SortAggregate(isMerge=[true], groupBy=[d], select=[d, Final_MIN(min$0) AS EXPR$1, Final_MIN(min$1) AS EXPR$2, Final_COUNT(count$2) AS EXPR$3, Final_COUNT(count$3) AS EXPR$4, Final_COUNT(count$4) AS EXPR$5]) ++- Sort(orderBy=[d ASC]) + +- Exchange(distribution=[hash[d]]) + +- LocalSortAggregate(groupBy=[d], select=[d, Partial_MIN(EXPR$1) FILTER $g_15 AS min$0, Partial_MIN(EXPR$2) FILTER $g_15 AS min$1, Partial_COUNT(c) FILTER $g_7 AS count$2, Partial_COUNT(c) FILTER $g_3 AS count$3, Partial_COUNT(b) FILTER $g_12 AS count$4]) + +- Calc(select=[d, c, b, EXPR$1, EXPR$2, AND(=(CASE(=($e, 3:BIGINT), 3:BIGINT, =($e, 7:BIGINT), 7:BIGINT, =($e, 12:BIGINT), 12:BIGINT, 15:BIGINT), 3), $f4) AS $g_3, =(CASE(=($e, 3:BIGINT), 3:BIGINT, =($e, 7:BIGINT), 7:BIGINT, =($e, 12:BIGINT), 12:BIGINT, 15:BIGINT), 7) AS $g_7, AND(=(CASE(=($e, 3:BIGINT), 3:BIGINT, =($e, 7:BIGINT), 7:BIGINT, =($e, 12:BIGINT), 12:BIGINT, 15:BIGINT), 12), $f6) AS $g_12, =(CASE(=($e, 3:BIGINT), 3:BIGINT, =($e, 7:BIGINT), 7:BIGINT, =($e, 12:BIGINT), [...] 
+ +- Sort(orderBy=[d ASC]) + +- SortAggregate(isMerge=[false], groupBy=[d, c, $f4, b, $f6, $e], select=[d, c, $f4, b, $f6, $e, MAX(e) AS EXPR$1, MAX(e) FILTER $f2 AS EXPR$2]) + +- Sort(orderBy=[d ASC, c ASC, $f4 ASC, b ASC, $f6 ASC, $e ASC]) + +- Exchange(distribution=[hash[d, c, $f4, b, $f6, $e]]) + +- Expand(projects=[d, e, $f2, c, $f4, b, $f6, $e], projects=[{d, e, $f2, c, $f4, null AS b, null AS $f6, 3 AS $e}, {d, e, $f2, c, null AS $f4, null AS b, null AS $f6, 7 AS $e}, {d, e, $f2, null AS c, null AS $f4, b, $f6, 12 AS $e}, {d, e, $f2, null AS c, null AS $f4, null AS b, null AS $f6, 15 AS $e}]) + +- Calc(select=[d, e, IS TRUE(<(a, 10)) AS $f2, c, IS TRUE(>(a, 5)) AS $f4, b, IS TRUE(>(b, 3)) AS $f6]) + +- TableSourceScan(table=[[default_catalog, default_database, z, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e]) +]]> + </Resource> + </TestCase> </Root> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/AggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/AggregateTest.xml index b7c4a6e..fd6047b 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/AggregateTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/AggregateTest.xml @@ -102,8 +102,8 @@ Calc(select=[EXPR$0]) <Resource name="planBefore"> <![CDATA[ LogicalProject(b=[$0], EXPR$0=[$1]) -+- LogicalAggregate(group=[{1}], EXPR$0=[AVG($3)]) - +- LogicalProject(a=[$0], b=[$1], c=[$2], a0=[CAST($0):DOUBLE]) ++- LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)]) + +- LogicalProject(b=[$1], a0=[CAST($0):DOUBLE]) +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c)]]]) ]]> </Resource> @@ -111,7 +111,7 @@ LogicalProject(b=[$0], EXPR$0=[$1]) <![CDATA[ GroupAggregate(groupBy=[b], 
select=[b, AVG(a0) AS EXPR$0]) +- Exchange(distribution=[hash[b]]) - +- Calc(select=[a, b, c, CAST(a) AS a0]) + +- Calc(select=[b, CAST(a) AS a0]) +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> </Resource> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.xml index 5cbc8cf..1cd0641 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TableAggregateTest.xml @@ -53,8 +53,8 @@ Calc(select=[f0 AS a, f1 AS b]) <Resource name="planBefore"> <![CDATA[ LogicalProject(bb=[AS($0, _UTF-16LE'bb')], _c1=[+(AS($1, _UTF-16LE'x'), 1)], y=[AS($2, _UTF-16LE'y')]) -+- LogicalTableAggregate(group=[{5}], tableAggregate=[[EmptyTableAggFunc($0, $1)]]) - +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], bb=[MOD($1, 5)]) ++- LogicalTableAggregate(group=[{2}], tableAggregate=[[EmptyTableAggFunc($0, $1)]]) + +- LogicalProject(a=[$0], b=[$1], bb=[MOD($1, 5)]) +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]]) ]]> </Resource> @@ -63,7 +63,7 @@ LogicalProject(bb=[AS($0, _UTF-16LE'bb')], _c1=[+(AS($1, _UTF-16LE'x'), 1)], y=[ Calc(select=[bb, +(f0, 1) AS _c1, f1 AS y]) +- GroupTableAggregate(groupBy=[bb], select=[bb, EmptyTableAggFunc(a, b) AS (f0, f1)]) +- Exchange(distribution=[hash[bb]]) - +- Calc(select=[a, b, c, d, e, MOD(b, 5) AS bb]) + +- Calc(select=[a, b, MOD(b, 5) AS bb]) +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e]) ]]> </Resource> 
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TwoStageAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TwoStageAggregateTest.xml index be5a7e3..8b59588 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TwoStageAggregateTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TwoStageAggregateTest.xml @@ -39,8 +39,8 @@ Calc(select=[EXPR$0]) <Resource name="planBefore"> <![CDATA[ LogicalProject(b=[$0], EXPR$0=[$1]) -+- LogicalAggregate(group=[{1}], EXPR$0=[AVG($3)]) - +- LogicalProject(a=[$0], b=[$1], c=[$2], a0=[CAST($0):DOUBLE]) ++- LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)]) + +- LogicalProject(b=[$1], a0=[CAST($0):DOUBLE]) +- LogicalTableScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c)]]]) ]]> </Resource> @@ -49,7 +49,7 @@ LogicalProject(b=[$0], EXPR$0=[$1]) GlobalGroupAggregate(groupBy=[b], select=[b, AVG((sum$0, count$1)) AS EXPR$0]) +- Exchange(distribution=[hash[b]]) +- LocalGroupAggregate(groupBy=[b], select=[b, AVG(a0) AS (sum$0, count$1)]) - +- Calc(select=[a, b, c, CAST(a) AS a0]) + +- Calc(select=[b, CAST(a) AS a0]) +- MiniBatchAssigner(interval=[1000ms], mode=[ProcTime]) +- TableSourceScan(table=[[default_catalog, default_database, Table1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.scala index cad9754..d155b5c 100644 --- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.scala @@ -35,6 +35,7 @@ class RemoveRedundantLocalHashAggRuleTest extends TableTestBase { def setup(): Unit = { util.addTableSource[(Int, Long, String)]("x", 'a, 'b, 'c) util.addTableSource[(Int, Long, String)]("y", 'd, 'e, 'f) + util.addTableSource[(Int, Long, Long, Long, Long)]("z", 'a, 'b, 'c, 'd, 'e) } @Test @@ -69,4 +70,16 @@ class RemoveRedundantLocalHashAggRuleTest extends TableTestBase { util.verifyPlan(sqlQuery) } + @Test + def testUsingLocalAggCallFilters(): Unit = { + util.tableEnv.getConfig.getConfiguration.setString( + ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "SortAgg") + util.tableEnv.getConfig.getConfiguration.setBoolean( + BatchExecJoinRuleBase.TABLE_OPTIMIZER_SHUFFLE_BY_PARTIAL_KEY_ENABLED, true) + val sqlQuery = "SELECT d, MAX(e), MAX(e) FILTER (WHERE a < 10), COUNT(DISTINCT c),\n" + + "COUNT(DISTINCT c) FILTER (WHERE a > 5), COUNT(DISTINCT b) FILTER (WHERE b > 3)\n" + + "FROM z GROUP BY d" + util.verifyPlan(sqlQuery) + } + } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.scala index a7f72c9..2957254 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.scala @@ -35,6 +35,7 @@ class 
RemoveRedundantLocalSortAggRuleTest extends TableTestBase { def setup(): Unit = { util.addTableSource[(Int, Long, String)]("x", 'a, 'b, 'c) util.addTableSource[(Int, Long, String)]("y", 'd, 'e, 'f) + util.addTableSource[(Int, Long, Long, Long, Long)]("z", 'a, 'b, 'c, 'd, 'e) } @Test @@ -64,4 +65,14 @@ class RemoveRedundantLocalSortAggRuleTest extends TableTestBase { util.verifyPlan(sqlQuery) } + @Test + def testUsingLocalAggCallFilters(): Unit = { + util.tableEnv.getConfig.getConfiguration.setString( + ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg") + val sqlQuery = "SELECT d, MAX(e), MAX(e) FILTER (WHERE a < 10), COUNT(DISTINCT c),\n" + + "COUNT(DISTINCT c) FILTER (WHERE a > 5), COUNT(DISTINCT b) FILTER (WHERE b > 3)\n" + + "FROM z GROUP BY d" + util.verifyPlan(sqlQuery) + } + } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.java index 80c297e..ba703a4 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.java @@ -25,6 +25,7 @@ import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate; import org.apache.flink.table.plan.logical.rel.LogicalWindowTableAggregate; import org.apache.flink.table.plan.logical.rel.TableAggregate; +import org.apache.calcite.plan.Contexts; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptRuleOperand; import org.apache.calcite.rel.RelNode; @@ -63,7 +64,12 @@ public class ExtendedAggregateExtractProjectRule extends AggregateExtractProject public static final ExtendedAggregateExtractProjectRule INSTANCE = new ExtendedAggregateExtractProjectRule( 
operand(SingleRel.class, - operand(RelNode.class, any())), RelFactories.LOGICAL_BUILDER); + operand(RelNode.class, any())), + RelBuilder.proto( + Contexts.of( + RelFactories.DEFAULT_STRUCT, + RelBuilder.Config.DEFAULT + .withPruneInputOfAggregate(false)))); public ExtendedAggregateExtractProjectRule( RelOptRuleOperand operand, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala index 75a57b3..838c99a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala @@ -20,7 +20,6 @@ package org.apache.flink.table.calcite import java.lang.Iterable import java.util.{List => JList} - import org.apache.calcite.plan._ import org.apache.calcite.rel.logical.LogicalAggregate import org.apache.calcite.tools.RelBuilder @@ -33,6 +32,8 @@ import org.apache.flink.table.plan.logical.LogicalWindow import org.apache.flink.table.plan.logical.rel.{LogicalTableAggregate, LogicalWindowAggregate, LogicalWindowTableAggregate} import org.apache.flink.table.runtime.aggregate.AggregateUtil +import java.util.function.UnaryOperator + import scala.collection.JavaConverters._ /** @@ -86,7 +87,23 @@ class FlinkRelBuilder( aggCalls: Iterable[AggCall]) : RelBuilder = { // build logical aggregate - val aggregate = super.aggregate(groupKey, aggCalls).build().asInstanceOf[LogicalAggregate] + + // Because of: + // [CALCITE-3763] RelBuilder.aggregate should prune unused fields + // from the input, if the input is a Project. + // + // the field cannot be pruned if it is referenced by other expressions + // of the window aggregation (i.e. the TUMBLE_START/END). + // To solve this, we configure the RelBuilder to disable this feature.
+ val aggregate = transform( + new UnaryOperator[RelBuilder.Config] { + override def apply(t: RelBuilder.Config) + : RelBuilder.Config = t.withPruneInputOfAggregate(false) + }) + .push(build()) + .aggregate(groupKey, aggCalls) + .build() + .asInstanceOf[LogicalAggregate] val namedProperties = windowProperties.asScala.map { case Alias(p: WindowProperty, name, _) => diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala index 431fe9e..a866b65 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala @@ -21,23 +21,33 @@ import com.google.common.collect.ImmutableList import org.apache.calcite.plan._ import org.apache.calcite.plan.hep.HepRelVertex import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.{Aggregate, AggregateCall} +import org.apache.calcite.rel.core.{Aggregate, AggregateCall, Project, RelFactories} import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject} import org.apache.calcite.rex._ import org.apache.calcite.sql.`type`.SqlTypeUtil import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.table.api._ import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty +import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.catalog.BasicOperatorTable import org.apache.flink.table.plan.logical.LogicalWindow import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate +import org.apache.calcite.rel.RelNode +import org.apache.calcite.tools.RelBuilder + +import _root_.java.util.{ArrayList => JArrayList, Collections, List => JList} import 
_root_.scala.collection.JavaConversions._ abstract class LogicalWindowAggregateRule(ruleName: String) extends RelOptRule( RelOptRule.operand(classOf[LogicalAggregate], RelOptRule.operand(classOf[LogicalProject], RelOptRule.none())), + RelBuilder.proto( + Contexts.of( + RelFactories.DEFAULT_STRUCT, + RelBuilder.Config.DEFAULT + .withBloat(-1))), ruleName) { override def matches(call: RelOptRuleCall): Boolean = { @@ -50,7 +60,7 @@ abstract class LogicalWindowAggregateRule(ruleName: String) throw new TableException("Only a single window group function may be used in GROUP BY") } - !groupSets && !agg.indicator && windowExpressions.nonEmpty + !groupSets && windowExpressions.nonEmpty } /** @@ -61,8 +71,15 @@ abstract class LogicalWindowAggregateRule(ruleName: String) * that the types are equivalent. */ override def onMatch(call: RelOptRuleCall): Unit = { - val agg = call.rel[LogicalAggregate](0) - val project = agg.getInput.asInstanceOf[HepRelVertex].getCurrentRel.asInstanceOf[LogicalProject] + val agg0 = call.rel[LogicalAggregate](0) + val project0 = call.rel[LogicalProject](1) + val project = rewriteWindowCallWithFuncOperands(project0, call.builder()) + val agg = if (project != project0) { + agg0.copy(agg0.getTraitSet, Collections.singletonList(project)) + .asInstanceOf[LogicalAggregate] + } else { + agg0 + } val (windowExpr, windowExprIdx) = getWindowExpressions(agg).head val window = translateWindowExpression(windowExpr, project.getInput.getRowType) @@ -90,7 +107,6 @@ abstract class LogicalWindowAggregateRule(ruleName: String) // we don't use the builder here because it uses RelMetadataQuery which affects the plan val newAgg = LogicalAggregate.create( newProject, - agg.indicator, newGroupSet, ImmutableList.of(newGroupSet), finalCalls) @@ -125,6 +141,104 @@ abstract class LogicalWindowAggregateRule(ruleName: String) call.transformTo(result) } + /** Trim out the HepRelVertex wrapper and get current relational expression. 
*/ + private def trimHep(node: RelNode): RelNode = { + node match { + case hepRelVertex: HepRelVertex => + hepRelVertex.getCurrentRel + case _ => node + } + } + + /** + * Rewrite plan with function call as window call operand: rewrite the window call to + * reference the input instead of invoking the function directly, in order to simplify the + * subsequent rewrite logic. + * + * For example, plan + * <pre> + * LogicalAggregate(group=[{0}], a=[COUNT()]) + * LogicalProject($f0=[$TUMBLE(TUMBLE_ROWTIME($0), 4:INTERVAL SECOND)], a=[$1]) + * LogicalProject($f0=[1970-01-01 00:00:00:TIMESTAMP(3)], a=[$0]) + * </pre> + * + * would be rewritten to + * <pre> + * LogicalAggregate(group=[{0}], a=[COUNT()]) + * LogicalProject($f0=[TUMBLE($1, 4:INTERVAL SECOND)], a=[$0]) + * LogicalProject(a=[$1], zzzzz=[TUMBLE_ROWTIME($0)]) + * LogicalProject($f0=[1970-01-01 00:00:00:TIMESTAMP(3)], a=[$0]) + * </pre> + */ + private def rewriteWindowCallWithFuncOperands( + project: LogicalProject, + relBuilder: RelBuilder): LogicalProject = { + val projectInput = trimHep(project.getInput) + if (!projectInput.isInstanceOf[Project]) { + return project + } + val inputProjects = projectInput.asInstanceOf[Project].getChildExps + var hasWindowCallWithFuncOperands: Boolean = false + var lastIdx = projectInput.getRowType.getFieldCount - 1; + val pushDownCalls = new JArrayList[RexNode]() + 0 until projectInput.getRowType.getFieldCount foreach { + idx => pushDownCalls.add(RexInputRef.of(idx, projectInput.getRowType)) + } + val newProjectExprs = project.getChildExps.map { + case call: RexCall if isWindowCall(call) && + isTimeAttributeCall(call.getOperands.head, inputProjects) => + hasWindowCallWithFuncOperands = true + // Update the window call to reference a RexInputRef instead of a function call. 
+ call.accept( + new RexShuttle { + override def visitCall(call: RexCall): RexNode = { + if (isTimeAttributeCall(call, inputProjects)) { + lastIdx += 1 + pushDownCalls.add(call) + relBuilder.getRexBuilder.makeInputRef( + call.getType, + // We would project plus an additional function call + // at the end of input projection. + lastIdx) + } else { + super.visitCall(call) + } + } + }) + case rex: RexNode => rex + } + + if (hasWindowCallWithFuncOperands) { + relBuilder + .push(projectInput) + // project plus the function call. + .project(pushDownCalls) + .project(newProjectExprs, project.getRowType.getFieldNames) + .build() + .asInstanceOf[LogicalProject] + } else { + project + } + } + + /** Decides if the [[RexNode]] is a call whose return type is + * a time indicator type. */ + def isTimeAttributeCall(rexNode: RexNode, projects: JList[RexNode]): Boolean = rexNode match { + case call: RexCall if FlinkTypeFactory.isTimeIndicatorType(call.getType) => + call.getOperands.forall { operand => + operand.isInstanceOf[RexInputRef] + } + case _ => false + } + + /** Decides whether the [[RexCall]] is a window call. */ + def isWindowCall(call: RexCall): Boolean = call.getOperator match { + case BasicOperatorTable.SESSION | + BasicOperatorTable.HOP | + BasicOperatorTable.TUMBLE => true + case _ => false + } + /** * Change the types of [[AggregateCall]] to the corresponding inferred types. 
*/ @@ -182,7 +296,7 @@ abstract class LogicalWindowAggregateRule(ruleName: String) private[table] def getWindowExpressions(agg: LogicalAggregate): Seq[(RexCall, Int)] = { - val project = agg.getInput.asInstanceOf[HepRelVertex].getCurrentRel.asInstanceOf[LogicalProject] + val project = trimHep(agg.getInput).asInstanceOf[LogicalProject] val groupKeys = agg.getGroupSet // get grouping expressions diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/WindowAggregateReduceFunctionsRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/WindowAggregateReduceFunctionsRule.scala index 50c0758..38bcaf5 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/WindowAggregateReduceFunctionsRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/WindowAggregateReduceFunctionsRule.scala @@ -19,8 +19,7 @@ package org.apache.flink.table.plan.rules.common import java.util - -import org.apache.calcite.plan.RelOptRule +import org.apache.calcite.plan.{Contexts, RelOptRule} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{Aggregate, AggregateCall, RelFactories} import org.apache.calcite.rel.logical.LogicalAggregate @@ -35,9 +34,14 @@ import scala.collection.JavaConversions._ * Rule to convert complex aggregation functions into simpler ones. * Have a look at [[AggregateReduceFunctionsRule]] for details. 
*/ -class WindowAggregateReduceFunctionsRule extends AggregateReduceFunctionsRule( - RelOptRule.operand(classOf[LogicalWindowAggregate], RelOptRule.any()), - RelFactories.LOGICAL_BUILDER) { +class WindowAggregateReduceFunctionsRule + extends AggregateReduceFunctionsRule( + RelOptRule.operand(classOf[LogicalWindowAggregate], RelOptRule.any()), + RelBuilder.proto( + Contexts.of( + RelFactories.DEFAULT_STRUCT, + RelBuilder.Config.DEFAULT + .withPruneInputOfAggregate(false)))) { override def newAggregateRel( relBuilder: RelBuilder,
