This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1eaf54b73de34fe86057675be090032746cd6049 Author: godfreyhe <[email protected]> AuthorDate: Wed Dec 23 16:58:38 2020 +0800 [FLINK-20737][table-planner-blink] Introduce StreamPhysicalIncrementalGroupAggregate, and make StreamExecIncrementalGroupAggregate only extended from ExecNode This closes #14478 --- .../StreamExecIncrementalGroupAggregate.java | 188 +++++++++++++++++++ .../plan/metadata/FlinkRelMdColumnInterval.scala | 8 +- .../metadata/FlinkRelMdModifiedMonotonicity.scala | 2 +- .../StreamExecIncrementalGroupAggregate.scala | 205 --------------------- .../StreamPhysicalGlobalGroupAggregate.scala | 4 +- .../StreamPhysicalIncrementalGroupAggregate.scala | 132 +++++++++++++ .../physical/stream/IncrementalAggregateRule.scala | 73 +++----- .../table/planner/plan/utils/AggregateUtil.scala | 59 +++++- 8 files changed, 407 insertions(+), 264 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecIncrementalGroupAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecIncrementalGroupAggregate.java new file mode 100644 index 0000000..cab48e4 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecIncrementalGroupAggregate.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.exec.stream; + +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.transformations.OneInputTransformation; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.planner.codegen.CodeGeneratorContext; +import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator; +import org.apache.flink.table.planner.delegation.PlannerBase; +import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.plan.utils.KeySelectorUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.operators.aggregate.MiniBatchIncrementalGroupAggFunction; +import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.RowType; + +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.tools.RelBuilder; + +import java.util.Arrays; +import java.util.Collections; + +/** Stream {@link ExecNode} for unbounded incremental group aggregate. */ +public class StreamExecIncrementalGroupAggregate extends ExecNodeBase<RowData> + implements StreamExecNode<RowData> { + + /** The partial agg's grouping. */ + private final int[] partialAggGrouping; + /** The final agg's grouping. */ + private final int[] finalAggGrouping; + /** The partial agg's original agg calls. */ + private final AggregateCall[] partialOriginalAggCalls; + /** Each element indicates whether the corresponding agg call needs `retract` method. */ + private final boolean[] partialAggCallNeedRetractions; + /** The input row type of this node's partial local agg. */ + private final RowType partialLocalAggInputRowType; + /** Whether this node consumes retraction messages. */ + private final boolean partialAggNeedRetraction; + + public StreamExecIncrementalGroupAggregate( + int[] partialAggGrouping, + int[] finalAggGrouping, + AggregateCall[] partialOriginalAggCalls, + boolean[] partialAggCallNeedRetractions, + RowType partialLocalAggInputRowType, + boolean partialAggNeedRetraction, + ExecEdge inputEdge, + RowType outputType, + String description) { + super(Collections.singletonList(inputEdge), outputType, description); + this.partialAggGrouping = partialAggGrouping; + this.finalAggGrouping = finalAggGrouping; + this.partialOriginalAggCalls = partialOriginalAggCalls; + this.partialAggCallNeedRetractions = partialAggCallNeedRetractions; + this.partialLocalAggInputRowType = partialLocalAggInputRowType; + this.partialAggNeedRetraction = partialAggNeedRetraction; + } + + @SuppressWarnings("unchecked") + @Override + protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) { + final TableConfig config = planner.getTableConfig(); + final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0); + final Transformation<RowData> inputTransform = inputNode.translateToPlan(planner); + + final AggregateInfoList partialLocalAggInfoList = + AggregateUtil.createPartialAggInfoList( + partialLocalAggInputRowType, + JavaScalaConversionUtil.toScala(Arrays.asList(partialOriginalAggCalls)), + partialAggCallNeedRetractions, + partialAggNeedRetraction, + false); + + final GeneratedAggsHandleFunction partialAggsHandler = + generateAggsHandler( + "PartialGroupAggsHandler", + partialLocalAggInfoList, + partialAggGrouping.length, + partialLocalAggInfoList.getAccTypes(), + config, + planner.getRelBuilder(), + // the partial aggregate accumulators will be buffered, so need copy + true); + + final AggregateInfoList incrementalAggInfo = + AggregateUtil.createIncrementalAggInfoList( + partialLocalAggInputRowType, + JavaScalaConversionUtil.toScala(Arrays.asList(partialOriginalAggCalls)), + partialAggCallNeedRetractions, + partialAggNeedRetraction); + + final GeneratedAggsHandleFunction finalAggsHandler = + generateAggsHandler( + "FinalGroupAggsHandler", + incrementalAggInfo, + 0, + partialLocalAggInfoList.getAccTypes(), + config, + planner.getRelBuilder(), + // the final aggregate accumulators is not buffered + false); + + final RowDataKeySelector partialKeySelector = + KeySelectorUtil.getRowDataSelector( + partialAggGrouping, InternalTypeInfo.of(inputNode.getOutputType())); + final RowDataKeySelector finalKeySelector = + KeySelectorUtil.getRowDataSelector( + finalAggGrouping, partialKeySelector.getProducedType()); + + final MiniBatchIncrementalGroupAggFunction aggFunction = + new MiniBatchIncrementalGroupAggFunction( + partialAggsHandler, + finalAggsHandler, + finalKeySelector, + config.getIdleStateRetention().toMillis()); + + final OneInputStreamOperator<RowData, RowData> operator = + new KeyedMapBundleOperator<>( + aggFunction, AggregateUtil.createMiniBatchTrigger(config)); + + // partitioned aggregation + final OneInputTransformation<RowData, RowData> transform = + new OneInputTransformation<>( + inputTransform, + getDesc(), + operator, + InternalTypeInfo.of(getOutputType()), + inputTransform.getParallelism()); + + if (inputsContainSingleton()) { + transform.setParallelism(1); + transform.setMaxParallelism(1); + } + + // set KeyType and Selector for state + transform.setStateKeySelector(partialKeySelector); + transform.setStateKeyType(partialKeySelector.getProducedType()); + return transform; + } + + private GeneratedAggsHandleFunction generateAggsHandler( + String name, + AggregateInfoList aggInfoList, + int mergedAccOffset, + DataType[] mergedAccExternalTypes, + TableConfig config, + RelBuilder relBuilder, + boolean inputFieldCopy) { + + AggsHandlerCodeGenerator generator = + new AggsHandlerCodeGenerator( + new CodeGeneratorContext(config), + relBuilder, + JavaScalaConversionUtil.toScala(partialLocalAggInputRowType.getChildren()), + inputFieldCopy); + + return generator + .needMerge(mergedAccOffset, true, mergedAccExternalTypes) + .generateAggsHandler(name, aggInfoList); + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala index 5153643..10ec872 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala @@ -538,7 +538,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { case agg: StreamPhysicalGroupAggregate => agg.grouping case agg: StreamPhysicalLocalGroupAggregate => agg.grouping case agg: StreamPhysicalGlobalGroupAggregate => agg.grouping - case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping + case agg: StreamPhysicalIncrementalGroupAggregate => agg.partialAggGrouping case agg: StreamExecGroupWindowAggregate => agg.getGrouping case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ agg.getAuxGrouping case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg) @@ -608,9 +608,9 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } case agg: StreamPhysicalLocalGroupAggregate => getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType) - case agg: StreamExecIncrementalGroupAggregate - if agg.partialAggInfoList.getActualAggregateCalls.length > aggCallIndex => - agg.partialAggInfoList.getActualAggregateCalls(aggCallIndex) + case agg: StreamPhysicalIncrementalGroupAggregate + if agg.partialAggCalls.length > aggCallIndex => + agg.partialAggCalls(aggCallIndex) case agg: StreamExecGroupWindowAggregate if agg.aggCalls.length > aggCallIndex => agg.aggCalls(aggCallIndex) case agg: BatchExecLocalHashAggregate => diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala index 272b557..9b0932a 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala @@ -303,7 +303,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } def getRelModifiedMonotonicity( - rel: StreamExecIncrementalGroupAggregate, + rel: StreamPhysicalIncrementalGroupAggregate, mq: RelMetadataQuery): RelModifiedMonotonicity = { getRelModifiedMonotonicityOnAggregate( rel.getInput, mq, rel.finalAggCalls.toList, rel.finalAggGrouping) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala deleted file mode 100644 index bcb6a93..0000000 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.planner.plan.nodes.physical.stream - -import org.apache.flink.api.dag.Transformation -import org.apache.flink.streaming.api.transformations.OneInputTransformation -import org.apache.flink.table.api.TableConfig -import org.apache.flink.table.data.RowData -import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.codegen.CodeGeneratorContext -import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator -import org.apache.flink.table.planner.delegation.StreamPlanner -import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode -import org.apache.flink.table.planner.plan.utils.{KeySelectorUtil, _} -import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction -import org.apache.flink.table.runtime.operators.aggregate.MiniBatchIncrementalGroupAggFunction -import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo -import org.apache.flink.table.types.DataType - -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rel.{RelNode, RelWriter} -import org.apache.calcite.tools.RelBuilder - -import java.util - -import scala.collection.JavaConversions._ - -/** - * Stream physical RelNode for unbounded incremental group aggregate. - * - * <p>Considering the following sub-plan: - * {{{ - * StreamPhysicalGlobalGroupAggregate (final-global-aggregate) - * +- StreamPhysicalExchange - * +- StreamPhysicalLocalGroupAggregate (final-local-aggregate) - * +- StreamPhysicalGlobalGroupAggregate (partial-global-aggregate) - * +- StreamPhysicalExchange - * +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate) - * }}} - * - * partial-global-aggregate and final-local-aggregate can be combined as - * this node to share [[org.apache.flink.api.common.state.State]]. - * now the sub-plan is - * {{{ - * StreamPhysicalGlobalGroupAggregate (final-global-aggregate) - * +- StreamPhysicalExchange - * +- StreamExecIncrementalGroupAggregate - * +- StreamPhysicalExchange - * +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate) - * }}} - * - * @see [[StreamPhysicalGroupAggregateBase]] for more info. - */ -class StreamExecIncrementalGroupAggregate( - cluster: RelOptCluster, - traitSet: RelTraitSet, - inputRel: RelNode, - inputRowType: RelDataType, - outputRowType: RelDataType, - val partialAggInfoList: AggregateInfoList, - finalAggInfoList: AggregateInfoList, - val finalAggCalls: Seq[AggregateCall], - val finalAggGrouping: Array[Int], - val partialAggGrouping: Array[Int]) - extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) - with LegacyStreamExecNode[RowData] { - - override def deriveRowType(): RelDataType = outputRowType - - override def requireWatermark: Boolean = false - - override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { - new StreamExecIncrementalGroupAggregate( - cluster, - traitSet, - inputs.get(0), - inputRowType, - outputRowType, - partialAggInfoList, - finalAggInfoList, - finalAggCalls, - finalAggGrouping, - partialAggGrouping) - } - - override def explainTerms(pw: RelWriter): RelWriter = { - super.explainTerms(pw) - .item("partialAggGrouping", - RelExplainUtil.fieldToString(partialAggGrouping, inputRel.getRowType)) - .item("finalAggGrouping", - RelExplainUtil.fieldToString(finalAggGrouping, inputRel.getRowType)) - .item("select", RelExplainUtil.streamGroupAggregationToString( - inputRel.getRowType, - getRowType, - finalAggInfoList, - finalAggGrouping, - shuffleKey = Some(partialAggGrouping))) - } - - //~ ExecNode methods ----------------------------------------------------------- - - override protected def translateToPlanInternal( - planner: StreamPlanner): Transformation[RowData] = { - val config = planner.getTableConfig - val inputTransformation = getInputNodes.get(0).translateToPlan(planner) - .asInstanceOf[Transformation[RowData]] - - val inRowType = FlinkTypeFactory.toLogicalRowType(inputRel.getRowType) - val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType) - - val partialAggsHandler = generateAggsHandler( - "PartialGroupAggsHandler", - partialAggInfoList, - mergedAccOffset = partialAggGrouping.length, - partialAggInfoList.getAccTypes, - config, - planner.getRelBuilder, - // the partial aggregate accumulators will be buffered, so need copy - inputFieldCopy = true) - - val finalAggsHandler = generateAggsHandler( - "FinalGroupAggsHandler", - finalAggInfoList, - mergedAccOffset = 0, - partialAggInfoList.getAccTypes, - config, - planner.getRelBuilder, - // the final aggregate accumulators is not buffered - inputFieldCopy = false) - - val partialKeySelector = KeySelectorUtil.getRowDataSelector( - partialAggGrouping, - InternalTypeInfo.of(inRowType)) - val finalKeySelector = KeySelectorUtil.getRowDataSelector( - finalAggGrouping, - partialKeySelector.getProducedType) - - val aggFunction = new MiniBatchIncrementalGroupAggFunction( - partialAggsHandler, - finalAggsHandler, - finalKeySelector, - config.getIdleStateRetention.toMillis) - - val operator = new KeyedMapBundleOperator( - aggFunction, - AggregateUtil.createMiniBatchTrigger(config)) - - // partitioned aggregation - val ret = new OneInputTransformation( - inputTransformation, - getRelDetailedDescription, - operator, - InternalTypeInfo.of(outRowType), - inputTransformation.getParallelism) - - if (inputsContainSingleton()) { - ret.setParallelism(1) - ret.setMaxParallelism(1) - } - - // set KeyType and Selector for state - ret.setStateKeySelector(partialKeySelector) - ret.setStateKeyType(partialKeySelector.getProducedType) - ret - } - - def generateAggsHandler( - name: String, - aggInfoList: AggregateInfoList, - mergedAccOffset: Int, - mergedAccExternalTypes: Array[DataType], - config: TableConfig, - relBuilder: RelBuilder, - inputFieldCopy: Boolean): GeneratedAggsHandleFunction = { - - val generator = new AggsHandlerCodeGenerator( - CodeGeneratorContext(config), - relBuilder, - FlinkTypeFactory.toLogicalRowType(inputRowType).getChildren, - inputFieldCopy) - - generator - .needAccumulate() - .needMerge(mergedAccOffset, mergedAccOnHeap = true, mergedAccExternalTypes) - .generateAggsHandler(name, aggInfoList) - } -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala index b4e0e2e..7cec0fc 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala @@ -40,9 +40,9 @@ class StreamPhysicalGlobalGroupAggregate( outputRowType: RelDataType, val grouping: Array[Int], val aggCalls: Seq[AggregateCall], - aggCallNeedRetractions: Array[Boolean], + val aggCallNeedRetractions: Array[Boolean], val localAggInputRowType: RelDataType, - needRetraction: Boolean, + val needRetraction: Boolean, val partialFinalType: PartialFinalType) extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalIncrementalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalIncrementalGroupAggregate.scala new file mode 100644 index 0000000..cb203b7 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalIncrementalGroupAggregate.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.plan.nodes.physical.stream + +import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecIncrementalGroupAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode} +import org.apache.flink.table.planner.plan.utils._ + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter} + +import java.util + +/** + * Stream physical RelNode for unbounded incremental group aggregate. + * + * <p>Considering the following sub-plan: + * {{{ + * StreamPhysicalGlobalGroupAggregate (final-global-aggregate) + * +- StreamPhysicalExchange + * +- StreamPhysicalLocalGroupAggregate (final-local-aggregate) + * +- StreamPhysicalGlobalGroupAggregate (partial-global-aggregate) + * +- StreamPhysicalExchange + * +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate) + * }}} + * + * partial-global-aggregate and final-local-aggregate can be combined as + * this node to share [[org.apache.flink.api.common.state.State]]. + * now the sub-plan is + * {{{ + * StreamPhysicalGlobalGroupAggregate (final-global-aggregate) + * +- StreamPhysicalExchange + * +- StreamPhysicalIncrementalGroupAggregate + * +- StreamPhysicalExchange + * +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate) + * }}} + * + * @see [[StreamPhysicalGroupAggregateBase]] for more info. + */ +class StreamPhysicalIncrementalGroupAggregate( + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputRel: RelNode, + val partialAggGrouping: Array[Int], + val partialAggCalls: Array[AggregateCall], + val finalAggGrouping: Array[Int], + val finalAggCalls: Array[AggregateCall], + partialOriginalAggCalls: Array[AggregateCall], + partialAggCallNeedRetractions: Array[Boolean], + partialAggNeedRetraction: Boolean, + partialLocalAggInputRowType: RelDataType, + partialGlobalAggRowType: RelDataType) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) { + + private lazy val incrementalAggInfo = AggregateUtil.createIncrementalAggInfoList( + FlinkTypeFactory.toLogicalRowType(partialLocalAggInputRowType), + partialOriginalAggCalls, + partialAggCallNeedRetractions, + partialAggNeedRetraction) + + override def deriveRowType(): RelDataType = { + AggregateUtil.inferLocalAggRowType( + incrementalAggInfo, + partialGlobalAggRowType, + finalAggGrouping, + getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]) + } + + override def requireWatermark: Boolean = false + + override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { + new StreamPhysicalIncrementalGroupAggregate( + cluster, + traitSet, + inputs.get(0), + partialAggGrouping, + partialAggCalls, + finalAggGrouping, + finalAggCalls, + partialOriginalAggCalls, + partialAggCallNeedRetractions, + partialAggNeedRetraction, + partialLocalAggInputRowType, + partialGlobalAggRowType) + } + + override def explainTerms(pw: RelWriter): RelWriter = { + super.explainTerms(pw) + .item("partialAggGrouping", + RelExplainUtil.fieldToString(partialAggGrouping, inputRel.getRowType)) + .item("finalAggGrouping", + RelExplainUtil.fieldToString(finalAggGrouping, inputRel.getRowType)) + .item("select", RelExplainUtil.streamGroupAggregationToString( + inputRel.getRowType, + getRowType, + incrementalAggInfo, + finalAggGrouping, + shuffleKey = Some(partialAggGrouping))) + } + + override def translateToExecNode(): ExecNode[_] = { + new StreamExecIncrementalGroupAggregate( + partialAggGrouping, + finalAggGrouping, + partialOriginalAggCalls, + partialAggCallNeedRetractions, + FlinkTypeFactory.toLogicalRowType(partialLocalAggInputRowType), + partialAggNeedRetraction, + ExecEdge.DEFAULT, + FlinkTypeFactory.toLogicalRowType(getRowType), + getRelDetailedDescription + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala index ee8f074..235c441 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala @@ -22,8 +22,8 @@ import org.apache.flink.configuration.ConfigOption import org.apache.flink.configuration.ConfigOptions.key import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory} import org.apache.flink.table.planner.plan.PartialFinalType -import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecIncrementalGroupAggregate, StreamPhysicalExchange, StreamPhysicalGlobalGroupAggregate, StreamPhysicalLocalGroupAggregate} -import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil, DistinctInfo} +import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamPhysicalExchange, StreamPhysicalGlobalGroupAggregate, StreamPhysicalIncrementalGroupAggregate, StreamPhysicalLocalGroupAggregate} +import org.apache.flink.table.planner.plan.utils.AggregateUtil import org.apache.flink.util.Preconditions import org.apache.calcite.plan.RelOptRule.{any, operand} @@ -37,7 +37,7 @@ import java.util.Collections * on final [[StreamPhysicalLocalGroupAggregate]] on partial [[StreamPhysicalGlobalGroupAggregate]], * and combines the final [[StreamPhysicalLocalGroupAggregate]] and * the partial [[StreamPhysicalGlobalGroupAggregate]] into a - * [[StreamExecIncrementalGroupAggregate]]. + * [[StreamPhysicalIncrementalGroupAggregate]]. */ class IncrementalAggregateRule extends RelOptRule( @@ -69,55 +69,26 @@ class IncrementalAggregateRule val exchange: StreamPhysicalExchange = call.rel(1) val finalLocalAgg: StreamPhysicalLocalGroupAggregate = call.rel(2) val partialGlobalAgg: StreamPhysicalGlobalGroupAggregate = call.rel(3) - val aggInputRowType = partialGlobalAgg.localAggInputRowType - - val partialLocalAggInfoList = partialGlobalAgg.localAggInfoList - val partialGlobalAggInfoList = partialGlobalAgg.globalAggInfoList - val finalGlobalAggInfoList = finalGlobalAgg.globalAggInfoList - val aggCalls = finalGlobalAggInfoList.getActualAggregateCalls - - val typeFactory = finalGlobalAgg.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] - - // pick distinct info from global which is on state, and modify excludeAcc parameter - val incrDistinctInfo = partialGlobalAggInfoList.distinctInfos.map { info => - DistinctInfo( - info.argIndexes, - info.keyType, - info.accType, - // exclude distinct acc from the aggregate accumulator, - // because the output acc only need to contain the count - excludeAcc = true, - info.dataViewSpec, - info.consumeRetraction, - info.filterArgs, - info.aggIndexes - ) - } - - val incrAggInfoList = AggregateInfoList( - // pick local aggs info from local which is on heap - partialLocalAggInfoList.aggInfos, - partialGlobalAggInfoList.indexOfCountStar, - partialGlobalAggInfoList.countStarInserted, - incrDistinctInfo) + val partialLocalAggInputRowType = partialGlobalAgg.localAggInputRowType - val incrAggOutputRowType = AggregateUtil.inferLocalAggRowType( - incrAggInfoList, - partialGlobalAgg.getRowType, - finalGlobalAgg.grouping, - typeFactory) + val partialOriginalAggCalls = partialGlobalAgg.aggCalls.toArray + val partialRealAggCalls = partialGlobalAgg.localAggInfoList.getActualAggregateCalls + val finalRealAggCalls = finalGlobalAgg.globalAggInfoList.getActualAggregateCalls - val incrAgg = new StreamExecIncrementalGroupAggregate( + val incrAgg = new StreamPhysicalIncrementalGroupAggregate( partialGlobalAgg.getCluster, finalLocalAgg.getTraitSet, // extends final local agg traits (ACC trait) partialGlobalAgg.getInput, - aggInputRowType, - incrAggOutputRowType, - partialLocalAggInfoList, - incrAggInfoList, - aggCalls, + partialGlobalAgg.grouping, + partialRealAggCalls, finalLocalAgg.grouping, - partialGlobalAgg.grouping) + finalRealAggCalls, + partialOriginalAggCalls, + partialGlobalAgg.aggCallNeedRetractions, + partialGlobalAgg.needRetraction, + partialLocalAggInputRowType, + partialGlobalAgg.getRowType) + val incrAggOutputRowType = incrAgg.getRowType val newExchange = exchange.copy(exchange.getTraitSet, incrAgg, exchange.distribution) @@ -135,9 +106,9 @@ class IncrementalAggregateRule val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( // the final agg input is partial agg FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType), - aggCalls, + finalRealAggCalls, // all the aggs do not need retraction - Array.fill(aggCalls.length)(false), + Array.fill(finalRealAggCalls.length)(false), // also do not need count* needInputCount = false, // the local agg is not works on state @@ -148,7 +119,7 @@ class IncrementalAggregateRule localAggInfoList, incrAgg.getRowType, finalGlobalAgg.grouping, - typeFactory) + finalGlobalAgg.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]) Preconditions.checkState(RelOptUtil.areRowTypesEqual( incrAggOutputRowType, @@ -161,9 +132,9 @@ class IncrementalAggregateRule newExchange, finalGlobalAgg.getRowType, finalGlobalAgg.grouping, - aggCalls, + finalRealAggCalls, // all the aggs do not need retraction - Array.fill(aggCalls.length)(false), + Array.fill(finalRealAggCalls.length)(false), finalGlobalAgg.localAggInputRowType, needRetraction = false, finalGlobalAgg.partialFinalType) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index ea1205b..43f3e3e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -175,6 +175,62 @@ object AggregateUtil extends Enumeration { map } + def createPartialAggInfoList( + partialLocalAggInputRowType: RowType, + partialOriginalAggCalls: Seq[AggregateCall], + partialAggCallNeedRetractions: Array[Boolean], + partialAggNeedRetraction: Boolean, + isGlobal: Boolean): AggregateInfoList = { + transformToStreamAggregateInfoList( + partialLocalAggInputRowType, + partialOriginalAggCalls, + partialAggCallNeedRetractions, + partialAggNeedRetraction, + isStateBackendDataViews = isGlobal) + } + + def createIncrementalAggInfoList( + partialLocalAggInputRowType: RowType, + partialOriginalAggCalls: Seq[AggregateCall], + partialAggCallNeedRetractions: Array[Boolean], + partialAggNeedRetraction: Boolean): AggregateInfoList = { + val partialLocalAggInfoList = createPartialAggInfoList( + partialLocalAggInputRowType, + partialOriginalAggCalls, + partialAggCallNeedRetractions, + partialAggNeedRetraction, + isGlobal = false) + val partialGlobalAggInfoList = createPartialAggInfoList( + partialLocalAggInputRowType, + partialOriginalAggCalls, + partialAggCallNeedRetractions, + partialAggNeedRetraction, + isGlobal = true) + + // pick distinct info from global which is on state, and modify excludeAcc parameter + val incrementalDistinctInfos = partialGlobalAggInfoList.distinctInfos.map { info => + DistinctInfo( + info.argIndexes, + info.keyType, + info.accType, + // exclude distinct acc from the aggregate accumulator, + // because the output acc only need to contain the count + excludeAcc = true, + info.dataViewSpec, + info.consumeRetraction, + info.filterArgs, + info.aggIndexes + ) + } + + AggregateInfoList( + // pick local aggs info from local which is on heap + partialLocalAggInfoList.aggInfos, + partialGlobalAggInfoList.indexOfCountStar, + partialGlobalAggInfoList.countStarInserted, + incrementalDistinctInfos) + } + def deriveAggregateInfoList( agg: StreamPhysicalRel, groupCount: Int, @@ -532,6 +588,7 @@ object AggregateUtil extends Enumeration { /** * Inserts an COUNT(*) aggregate call if needed. The COUNT(*) aggregate call is used * to count the number of added and retracted input records. + * * @param needInputCount whether to insert an InputCount aggregate * @param aggregateCalls original aggregate calls * @return (indexOfCountStar, countStarInserted, newAggCalls) @@ -885,7 +942,7 @@ object AggregateUtil extends Enumeration { def createMiniBatchTrigger(tableConfig: TableConfig): CountBundleTrigger[RowData] = { val size = tableConfig.getConfiguration.getLong( ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_SIZE) - if (size <= 0 ) { + if (size <= 0) { throw new IllegalArgumentException( ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_SIZE + " must be > 0.") }
