This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit fa6412793ec1c11b751a51abc0c1492bf3079573 Author: godfreyhe <[email protected]> AuthorDate: Wed Dec 23 15:30:16 2020 +0800 [FLINK-20737][table-planner-blink] Introduce StreamPhysicalLocalGroupAggregate, and make StreamExecLocalGroupAggregate only extended from ExecNode This closes #14478 --- .../exec/stream/StreamExecLocalGroupAggregate.java | 129 +++++++++++++++++++ .../plan/metadata/FlinkRelMdColumnInterval.scala | 9 +- .../plan/metadata/FlinkRelMdColumnUniqueness.scala | 2 +- .../FlinkRelMdFilteredColumnInterval.scala | 4 +- .../metadata/FlinkRelMdModifiedMonotonicity.scala | 2 +- .../plan/metadata/FlinkRelMdUniqueKeys.scala | 2 +- .../StreamExecIncrementalGroupAggregate.scala | 6 +- .../stream/StreamExecLocalGroupAggregate.scala | 142 --------------------- .../stream/StreamPhysicalLocalGroupAggregate.scala | 103 +++++++++++++++ .../physical/stream/IncrementalAggregateRule.scala | 12 +- .../stream/TwoStageOptimizedAggregateRule.scala | 48 +++---- .../plan/metadata/FlinkRelMdHandlerTestBase.scala | 24 ++-- 12 files changed, 286 insertions(+), 197 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalGroupAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalGroupAggregate.java new file mode 100644 index 0000000..26aa171 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalGroupAggregate.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.exec.stream; + +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.transformations.OneInputTransformation; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.planner.codegen.CodeGeneratorContext; +import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator; +import org.apache.flink.table.planner.delegation.PlannerBase; +import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.plan.utils.KeySelectorUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.operators.aggregate.MiniBatchLocalGroupAggFunction; +import org.apache.flink.table.runtime.operators.bundle.MapBundleOperator; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.RowType; + +import org.apache.calcite.rel.core.AggregateCall; + +import java.util.Arrays; +import java.util.Collections; + +/** Stream {@link ExecNode} for unbounded local group aggregate. */ +public class StreamExecLocalGroupAggregate extends ExecNodeBase<RowData> + implements StreamExecNode<RowData> { + + private final int[] grouping; + private final AggregateCall[] aggCalls; + /** Each element indicates whether the corresponding agg call needs `retract` method. */ + private final boolean[] aggCallNeedRetractions; + /** Whether this node consumes retraction messages. */ + private final boolean needRetraction; + + public StreamExecLocalGroupAggregate( + int[] grouping, + AggregateCall[] aggCalls, + boolean[] aggCallNeedRetractions, + boolean needRetraction, + ExecEdge inputEdge, + RowType outputType, + String description) { + super(Collections.singletonList(inputEdge), outputType, description); + this.grouping = grouping; + this.aggCalls = aggCalls; + this.aggCallNeedRetractions = aggCallNeedRetractions; + this.needRetraction = needRetraction; + } + + @SuppressWarnings("unchecked") + @Override + protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) { + final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0); + final Transformation<RowData> inputTransform = inputNode.translateToPlan(planner); + final RowType inputRowType = (RowType) inputNode.getOutputType(); + + final AggsHandlerCodeGenerator generator = + new AggsHandlerCodeGenerator( + new CodeGeneratorContext(planner.getTableConfig()), + planner.getRelBuilder(), + JavaScalaConversionUtil.toScala(inputRowType.getChildren()), + // the local aggregate result will be buffered, so need copy + true); + generator.needAccumulate().needMerge(0, true, null); + if (needRetraction) { + generator.needRetract(); + } + + final AggregateInfoList aggInfoList = + AggregateUtil.transformToStreamAggregateInfoList( + inputRowType, + JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), + aggCallNeedRetractions, + needRetraction, + false, + true); + final GeneratedAggsHandleFunction aggsHandler = + generator.generateAggsHandler("GroupAggsHandler", aggInfoList); + final MiniBatchLocalGroupAggFunction aggFunction = + new MiniBatchLocalGroupAggFunction(aggsHandler); + + final RowDataKeySelector selector = + KeySelectorUtil.getRowDataSelector( + grouping, (InternalTypeInfo<RowData>) inputTransform.getOutputType()); + + final MapBundleOperator<RowData, RowData, RowData, RowData> operator = + new MapBundleOperator<>( + aggFunction, + AggregateUtil.createMiniBatchTrigger(planner.getTableConfig()), + selector); + + final OneInputTransformation<RowData, RowData> transform = + new OneInputTransformation<>( + inputTransform, + getDesc(), + operator, + InternalTypeInfo.of(getOutputType()), + inputTransform.getParallelism()); + + if (inputsContainSingleton()) { + transform.setParallelism(1); + transform.setMaxParallelism(1); + } + + return transform; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala index 2e33587..9978208 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala @@ -459,7 +459,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { * @return interval of the given column on stream local group Aggregate */ def getColumnInterval( - aggregate: StreamExecLocalGroupAggregate, + aggregate: StreamPhysicalLocalGroupAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) @@ -536,7 +536,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) val groupSet = aggregate match { case agg: StreamPhysicalGroupAggregate => agg.grouping - case agg: StreamExecLocalGroupAggregate => agg.grouping + case agg: StreamPhysicalLocalGroupAggregate => agg.grouping case agg: StreamExecGlobalGroupAggregate => agg.grouping case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping case agg: StreamExecGroupWindowAggregate => agg.getGrouping @@ -606,9 +606,8 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } else { null } - case agg: StreamExecLocalGroupAggregate => - getAggCallFromLocalAgg( - aggCallIndex, agg.aggInfoList.getActualAggregateCalls, agg.getInput.getRowType) + case agg: StreamPhysicalLocalGroupAggregate => + getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType) case agg: StreamExecIncrementalGroupAggregate if agg.partialAggInfoList.getActualAggregateCalls.length > aggCallIndex => agg.partialAggInfoList.getActualAggregateCalls(aggCallIndex) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala index 5799ab0..13954d3 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala @@ -357,7 +357,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata } def areColumnsUnique( - rel: StreamExecLocalGroupAggregate, + rel: StreamPhysicalLocalGroupAggregate, mq: RelMetadataQuery, columns: ImmutableBitSet, ignoreNulls: Boolean): JBoolean = null diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala index c851c96..5444b3b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.FilteredColumnInterval import org.apache.flink.table.planner.plan.nodes.calcite.TableAggregate import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase -import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamExecLocalGroupAggregate} +import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGroupAggregate, StreamPhysicalLocalGroupAggregate} import org.apache.flink.table.planner.plan.stats.ValueInterval import org.apache.flink.table.planner.plan.utils.ColumnIntervalUtil import org.apache.flink.util.Preconditions.checkArgument @@ -200,7 +200,7 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC } def getFilteredColumnInterval( - aggregate: StreamExecLocalGroupAggregate, + aggregate: StreamPhysicalLocalGroupAggregate, mq: RelMetadataQuery, columnIndex: Int, filterArg: Int): ValueInterval = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala index d46ac17..1f2e1ee 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala @@ -297,7 +297,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } def getRelModifiedMonotonicity( - rel: StreamExecLocalGroupAggregate, + rel: StreamPhysicalLocalGroupAggregate, mq: RelMetadataQuery): RelModifiedMonotonicity = { getRelModifiedMonotonicityOnAggregate(rel.getInput, mq, rel.aggCalls.toList, rel.grouping) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala index bf9c1eb..96c5a9b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala @@ -352,7 +352,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu } def getUniqueKeys( - rel: StreamExecLocalGroupAggregate, + rel: StreamPhysicalLocalGroupAggregate, mq: RelMetadataQuery, ignoreNulls: Boolean): JSet[ImmutableBitSet] = null diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala index 9706253..556c4fc 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala @@ -50,10 +50,10 @@ import scala.collection.JavaConversions._ * {{{ * StreamExecGlobalGroupAggregate (final-global-aggregate) * +- StreamPhysicalExchange - * +- StreamExecLocalGroupAggregate (final-local-aggregate) + * +- StreamPhysicalLocalGroupAggregate (final-local-aggregate) * +- StreamExecGlobalGroupAggregate (partial-global-aggregate) * +- StreamPhysicalExchange - * +- StreamExecLocalGroupAggregate (partial-local-aggregate) + * +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate) * }}} * * partial-global-aggregate and final-local-aggregate can be combined as @@ -64,7 +64,7 @@ import scala.collection.JavaConversions._ * +- StreamPhysicalExchange * +- StreamExecIncrementalGroupAggregate * +- StreamPhysicalExchange - * +- StreamExecLocalGroupAggregate (partial-local-aggregate) + * +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate) * }}} * * @see [[StreamPhysicalGroupAggregateBase]] for more info. diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecLocalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecLocalGroupAggregate.scala deleted file mode 100644 index 730695c..0000000 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecLocalGroupAggregate.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.planner.plan.nodes.physical.stream - -import org.apache.flink.api.dag.Transformation -import org.apache.flink.api.java.functions.KeySelector -import org.apache.flink.streaming.api.transformations.OneInputTransformation -import org.apache.flink.table.data.RowData -import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.codegen.CodeGeneratorContext -import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator -import org.apache.flink.table.planner.delegation.StreamPlanner -import org.apache.flink.table.planner.plan.PartialFinalType -import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode -import org.apache.flink.table.planner.plan.utils.{KeySelectorUtil, _} -import org.apache.flink.table.runtime.operators.aggregate.MiniBatchLocalGroupAggFunction -import org.apache.flink.table.runtime.operators.bundle.MapBundleOperator -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo - -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rel.{RelNode, RelWriter} - -import java.util - -import scala.collection.JavaConversions._ - -/** - * Stream physical RelNode for unbounded local group aggregate. - * - * @see [[StreamPhysicalGroupAggregateBase]] for more info. - */ -class StreamExecLocalGroupAggregate( - cluster: RelOptCluster, - traitSet: RelTraitSet, - inputRel: RelNode, - outputRowType: RelDataType, - val grouping: Array[Int], - val aggCalls: Seq[AggregateCall], - val aggInfoList: AggregateInfoList, - val partialFinalType: PartialFinalType) - extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) - with LegacyStreamExecNode[RowData] { - - override def requireWatermark: Boolean = false - - override def deriveRowType(): RelDataType = outputRowType - - override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { - new StreamExecLocalGroupAggregate( - cluster, - traitSet, - inputs.get(0), - outputRowType, - grouping, - aggCalls, - aggInfoList, - partialFinalType) - } - - override def explainTerms(pw: RelWriter): RelWriter = { - val inputRowType = getInput.getRowType - super.explainTerms(pw) - .itemIf("groupBy", RelExplainUtil.fieldToString(grouping, inputRowType), - grouping.nonEmpty) - .itemIf("partialFinalType", partialFinalType, partialFinalType != PartialFinalType.NONE) - .item("select", RelExplainUtil.streamGroupAggregationToString( - inputRowType, - getRowType, - aggInfoList, - grouping, - isLocal = true)) - } - - //~ ExecNode methods ----------------------------------------------------------- - - override protected def translateToPlanInternal( - planner: StreamPlanner): Transformation[RowData] = { - val inputTransformation = getInputNodes.get(0).translateToPlan(planner) - .asInstanceOf[Transformation[RowData]] - val inRowType = FlinkTypeFactory.toLogicalRowType(getInput.getRowType) - val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType) - - val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this) - - val generator = new AggsHandlerCodeGenerator( - CodeGeneratorContext(planner.getTableConfig), - planner.getRelBuilder, - inRowType.getChildren, - // the local aggregate result will be buffered, so need copy - copyInputField = true) - - generator - .needAccumulate() - .needMerge(mergedAccOffset = 0, mergedAccOnHeap = true) - - if (needRetraction) { - generator.needRetract() - } - - val aggsHandler = generator.generateAggsHandler("GroupAggsHandler", aggInfoList) - val aggFunction = new MiniBatchLocalGroupAggFunction(aggsHandler) - - val inputTypeInfo = inputTransformation.getOutputType.asInstanceOf[InternalTypeInfo[RowData]] - val selector = KeySelectorUtil.getRowDataSelector(grouping, inputTypeInfo) - - val operator = new MapBundleOperator( - aggFunction, - AggregateUtil.createMiniBatchTrigger(planner.getTableConfig), - selector.asInstanceOf[KeySelector[RowData, RowData]]) - - val transformation = new OneInputTransformation( - inputTransformation, - getRelDetailedDescription, - operator, - InternalTypeInfo.of(outRowType), - inputTransformation.getParallelism) - - if (inputsContainSingleton()) { - transformation.setParallelism(1) - transformation.setMaxParallelism(1) - } - - transformation - } -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalGroupAggregate.scala new file mode 100644 index 0000000..f750e6e --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalGroupAggregate.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.plan.nodes.physical.stream + +import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.plan.PartialFinalType +import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecLocalGroupAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode} +import org.apache.flink.table.planner.plan.utils._ + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter} + +import java.util + +/** + * Stream physical RelNode for unbounded local group aggregate. + * + * @see [[StreamPhysicalGroupAggregateBase]] for more info. + */ +class StreamPhysicalLocalGroupAggregate( + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputRel: RelNode, + val grouping: Array[Int], + val aggCalls: Seq[AggregateCall], + aggCallNeedRetractions: Array[Boolean], + needRetraction: Boolean, + val partialFinalType: PartialFinalType) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) { + + private lazy val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(inputRel.getRowType), + aggCalls, + aggCallNeedRetractions, + needRetraction, + isStateBackendDataViews = false) + + override def requireWatermark: Boolean = false + + override def deriveRowType(): RelDataType = { + AggregateUtil.inferLocalAggRowType( + aggInfoList, + inputRel.getRowType, + grouping, + getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]) + } + + override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { + new StreamPhysicalLocalGroupAggregate( + cluster, + traitSet, + inputs.get(0), + grouping, + aggCalls, + aggCallNeedRetractions, + needRetraction, + partialFinalType) + } + + override def explainTerms(pw: RelWriter): RelWriter = { + val inputRowType = getInput.getRowType + super.explainTerms(pw) + .itemIf("groupBy", RelExplainUtil.fieldToString(grouping, inputRowType), + grouping.nonEmpty) + .itemIf("partialFinalType", partialFinalType, partialFinalType != PartialFinalType.NONE) + .item("select", RelExplainUtil.streamGroupAggregationToString( + inputRowType, + getRowType, + aggInfoList, + grouping, + isLocal = true)) + } + + override def translateToExecNode(): ExecNode[_] = { + new StreamExecLocalGroupAggregate( + grouping, + aggCalls.toArray, + aggCallNeedRetractions, + needRetraction, + ExecEdge.DEFAULT, + FlinkTypeFactory.toLogicalRowType(getRowType), + getRelDetailedDescription + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala index a0bb8fd..ce8b563 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala @@ -22,7 +22,7 @@ import org.apache.flink.configuration.ConfigOption import org.apache.flink.configuration.ConfigOptions.key import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory} import org.apache.flink.table.planner.plan.PartialFinalType -import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecIncrementalGroupAggregate, StreamExecLocalGroupAggregate, StreamPhysicalExchange} +import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecIncrementalGroupAggregate, StreamPhysicalLocalGroupAggregate, StreamPhysicalExchange} import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil, DistinctInfo} import org.apache.flink.util.Preconditions @@ -34,21 +34,21 @@ import java.util.Collections /** * Rule that matches final [[StreamExecGlobalGroupAggregate]] on [[StreamPhysicalExchange]] - * on final [[StreamExecLocalGroupAggregate]] on partial [[StreamExecGlobalGroupAggregate]], - * and combines the final [[StreamExecLocalGroupAggregate]] and + * on final [[StreamPhysicalLocalGroupAggregate]] on partial [[StreamExecGlobalGroupAggregate]], + * and combines the final [[StreamPhysicalLocalGroupAggregate]] and * the partial [[StreamExecGlobalGroupAggregate]] into a [[StreamExecIncrementalGroupAggregate]]. */ class IncrementalAggregateRule extends RelOptRule( operand(classOf[StreamExecGlobalGroupAggregate], // final global agg operand(classOf[StreamPhysicalExchange], // key by - operand(classOf[StreamExecLocalGroupAggregate], // final local agg + operand(classOf[StreamPhysicalLocalGroupAggregate], // final local agg operand(classOf[StreamExecGlobalGroupAggregate], any())))), // partial global agg "IncrementalAggregateRule") { override def matches(call: RelOptRuleCall): Boolean = { val finalGlobalAgg: StreamExecGlobalGroupAggregate = call.rel(0) - val finalLocalAgg: StreamExecLocalGroupAggregate = call.rel(2) + val finalLocalAgg: StreamPhysicalLocalGroupAggregate = call.rel(2) val partialGlobalAgg: StreamExecGlobalGroupAggregate = call.rel(3) val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig @@ -66,7 +66,7 @@ class IncrementalAggregateRule override def onMatch(call: RelOptRuleCall): Unit = { val finalGlobalAgg: StreamExecGlobalGroupAggregate = call.rel(0) val exchange: StreamPhysicalExchange = call.rel(1) - val finalLocalAgg: StreamExecLocalGroupAggregate = call.rel(2) + val finalLocalAgg: StreamPhysicalLocalGroupAggregate = call.rel(2) val partialGlobalAgg: StreamExecGlobalGroupAggregate = call.rel(3) val aggInputRowType = partialGlobalAgg.inputRowType diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala index 7f01e5c..2f804cc 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala @@ -28,9 +28,11 @@ import org.apache.flink.table.planner.plan.rules.physical.FlinkExpandConversionR import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil, ChangelogPlanUtils} import org.apache.flink.table.planner.utils.AggregatePhaseStrategy import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy + import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.core.AggregateCall import java.util @@ -46,7 +48,7 @@ import java.util * {{{ * StreamExecGlobalGroupAggregate * +- StreamPhysicalExchange - * +- StreamExecLocalGroupAggregate + * +- StreamPhysicalLocalGroupAggregate * +- input of exchange * }}} */ @@ -97,52 +99,43 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( realInput.asInstanceOf[StreamPhysicalRel]) val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) val monotonicity = fmq.getRelModifiedMonotonicity(agg) - val needRetractionArray = AggregateUtil.deriveAggCallNeedRetractions( + val aggCallNeedRetractions = AggregateUtil.deriveAggCallNeedRetractions( agg.grouping.length, agg.aggCalls, needRetraction, monotonicity) - val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(realInput.getRowType), - agg.aggCalls, - needRetractionArray, - needRetraction, - isStateBackendDataViews = false) - val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( FlinkTypeFactory.toLogicalRowType(realInput.getRowType), agg.aggCalls, - needRetractionArray, + aggCallNeedRetractions, needRetraction, isStateBackendDataViews = true) - val globalHashAgg = createTwoStageAgg(realInput, localAggInfoList, globalAggInfoList, agg) + val globalHashAgg = createTwoStageAgg( + realInput, agg.aggCalls, aggCallNeedRetractions, needRetraction, globalAggInfoList, agg) call.transformTo(globalHashAgg) } // the difference between localAggInfos and aggInfos is local agg use heap dataview, // but global agg use state dataview private def createTwoStageAgg( - input: RelNode, - localAggInfoList: AggregateInfoList, + realInput: RelNode, + aggCalls: Seq[AggregateCall], + aggCallNeedRetractions: Array[Boolean], + needRetraction: Boolean, globalAggInfoList: AggregateInfoList, agg: StreamPhysicalGroupAggregate): StreamExecGlobalGroupAggregate = { - val localAggRowType = AggregateUtil.inferLocalAggRowType( - localAggInfoList, - input.getRowType, - agg.grouping, - input.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]) // local agg shouldn't produce insert only messages - val localAggTraitSet = input.getTraitSet + val localAggTraitSet = realInput.getTraitSet .plus(ModifyKindSetTrait.INSERT_ONLY) .plus(UpdateKindTrait.NONE) - val localHashAgg = new StreamExecLocalGroupAggregate( + val localHashAgg = new StreamPhysicalLocalGroupAggregate( agg.getCluster, localAggTraitSet, - input, - localAggRowType, + realInput, agg.grouping, agg.aggCalls, - localAggInfoList, + aggCallNeedRetractions, + needRetraction, agg.partialFinalType) // grouping keys is forwarded by local agg, use indices instead of groupings @@ -153,11 +146,18 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( FlinkConventions.STREAM_PHYSICAL, localHashAgg, globalDistribution) val globalAggProvidedTraitSet = agg.getTraitSet + // TODO Temporary solution, remove it later + val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(realInput.getRowType), + aggCalls, + aggCallNeedRetractions, + needRetraction, + isStateBackendDataViews = false) new StreamExecGlobalGroupAggregate( agg.getCluster, globalAggProvidedTraitSet, newInput, - input.getRowType, + realInput.getRowType, agg.getRowType, globalGrouping, localAggInfoList, diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala index f0e623f..ec2d55c 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala @@ -1019,23 +1019,16 @@ class FlinkRelMdHandlerTestBase { aggCallToAggFunction, isMerge = false) - val needRetractionArray = AggregateUtil.deriveAggCallNeedRetractions( + val aggCallNeedRetractions = AggregateUtil.deriveAggCallNeedRetractions( 1, aggCalls, needRetraction = false, null) - - val localAggInfoList = transformToStreamAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(studentStreamScan.getRowType), - aggCalls, - needRetractionArray, - needInputCount = false, - isStateBackendDataViews = false) - val streamLocalAgg = new StreamExecLocalGroupAggregate( + val streamLocalAgg = new StreamPhysicalLocalGroupAggregate( cluster, streamPhysicalTraits, studentStreamScan, - rowTypeOfLocalAgg, Array(3), aggCalls, - localAggInfoList, + aggCallNeedRetractions, + false, PartialFinalType.NONE) val streamExchange1 = new StreamPhysicalExchange( @@ -1043,9 +1036,16 @@ class FlinkRelMdHandlerTestBase { val globalAggInfoList = transformToStreamAggregateInfoList( FlinkTypeFactory.toLogicalRowType(streamExchange1.getRowType), aggCalls, - needRetractionArray, + aggCallNeedRetractions, needInputCount = false, isStateBackendDataViews = true) + // TODO Temporary solution, remove it later + val localAggInfoList = transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(studentStreamScan.getRowType), + aggCalls, + aggCallNeedRetractions, + needInputCount = false, + isStateBackendDataViews = false) val streamGlobalAgg = new StreamExecGlobalGroupAggregate( cluster, streamPhysicalTraits,
