This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 85e9f80b63b595e43e6926335d190a704c128ceb Author: godfreyhe <[email protected]> AuthorDate: Wed Dec 23 15:56:28 2020 +0800 [FLINK-20737][table-planner-blink] Introduce StreamPhysicalGlobalGroupAggregate, and make StreamExecGlobalGroupAggregate only extended from ExecNode This closes #14478 --- .../stream/StreamExecGlobalGroupAggregate.java | 226 +++++++++++++++++++++ .../plan/metadata/FlinkRelMdColumnInterval.scala | 10 +- .../plan/metadata/FlinkRelMdColumnUniqueness.scala | 2 +- .../FlinkRelMdFilteredColumnInterval.scala | 4 +- .../metadata/FlinkRelMdModifiedMonotonicity.scala | 2 +- .../plan/metadata/FlinkRelMdUniqueKeys.scala | 2 +- .../stream/StreamExecGlobalGroupAggregate.scala | 207 ------------------- .../StreamExecIncrementalGroupAggregate.scala | 6 +- .../StreamPhysicalGlobalGroupAggregate.scala | 108 ++++++++++ .../physical/stream/IncrementalAggregateRule.scala | 47 ++--- .../stream/TwoStageOptimizedAggregateRule.scala | 69 ++----- .../plan/metadata/FlinkRelMdHandlerTestBase.scala | 23 +-- 12 files changed, 393 insertions(+), 313 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalGroupAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalGroupAggregate.java new file mode 100644 index 0000000..fe273a7 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalGroupAggregate.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.exec.stream; + +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.transformations.OneInputTransformation; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.config.ExecutionConfigOptions; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.planner.codegen.CodeGeneratorContext; +import org.apache.flink.table.planner.codegen.EqualiserCodeGenerator; +import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator; +import org.apache.flink.table.planner.delegation.PlannerBase; +import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.plan.utils.KeySelectorUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction; +import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.operators.aggregate.MiniBatchGlobalGroupAggFunction; +import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator; +import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; + +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.tools.RelBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.Collections; + +/** Stream {@link ExecNode} for unbounded global group aggregate. */ +public class StreamExecGlobalGroupAggregate extends ExecNodeBase<RowData> + implements StreamExecNode<RowData> { + private static final Logger LOG = LoggerFactory.getLogger(StreamExecGlobalGroupAggregate.class); + + private final int[] grouping; + private final AggregateCall[] aggCalls; + /** Each element indicates whether the corresponding agg call needs `retract` method. */ + private final boolean[] aggCallNeedRetractions; + /** The input row type of this node's local agg. */ + private final RowType localAggInputRowType; + /** Whether this node will generate UPDATE_BEFORE messages. */ + private final boolean generateUpdateBefore; + /** Whether this node consumes retraction messages. */ + private final boolean needRetraction; + + public StreamExecGlobalGroupAggregate( + int[] grouping, + AggregateCall[] aggCalls, + boolean[] aggCallNeedRetractions, + RowType localAggInputRowType, + boolean generateUpdateBefore, + boolean needRetraction, + ExecEdge inputEdge, + RowType outputType, + String description) { + super(Collections.singletonList(inputEdge), outputType, description); + this.grouping = grouping; + this.aggCalls = aggCalls; + this.aggCallNeedRetractions = aggCallNeedRetractions; + this.localAggInputRowType = localAggInputRowType; + this.generateUpdateBefore = generateUpdateBefore; + this.needRetraction = needRetraction; + } + + @SuppressWarnings("unchecked") + @Override + protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) { + final TableConfig tableConfig = planner.getTableConfig(); + + if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime() < 0) { + LOG.warn( + "No state retention interval configured for a query which accumulates state. " + + "Please provide a query configuration with valid retention interval to prevent excessive " + + "state size. You may specify a retention time of 0 to not clean up the state."); + } + + final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0); + final Transformation<RowData> inputTransform = inputNode.translateToPlan(planner); + final RowType inputRowType = (RowType) inputNode.getOutputType(); + + final AggregateInfoList localAggInfoList = + AggregateUtil.transformToStreamAggregateInfoList( + localAggInputRowType, + JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), + aggCallNeedRetractions, + needRetraction, + false, + true); + final AggregateInfoList globalAggInfoList = + AggregateUtil.transformToStreamAggregateInfoList( + localAggInputRowType, + JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), + aggCallNeedRetractions, + needRetraction, + true, + true); + + final GeneratedAggsHandleFunction localAggsHandler = + generateAggsHandler( + "LocalGroupAggsHandler", + localAggInfoList, + grouping.length, + localAggInfoList.getAccTypes(), + tableConfig, + planner.getRelBuilder()); + + final GeneratedAggsHandleFunction globalAggsHandler = + generateAggsHandler( + "GlobalGroupAggsHandler", + globalAggInfoList, + 0, + localAggInfoList.getAccTypes(), + tableConfig, + planner.getRelBuilder()); + + final int indexOfCountStar = globalAggInfoList.getIndexOfCountStar(); + final LogicalType[] globalAccTypes = + Arrays.stream(globalAggInfoList.getAccTypes()) + .map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType) + .toArray(LogicalType[]::new); + final LogicalType[] globalAggValueTypes = + Arrays.stream(globalAggInfoList.getActualValueTypes()) + .map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType) + .toArray(LogicalType[]::new); + final GeneratedRecordEqualiser recordEqualiser = + new EqualiserCodeGenerator(globalAggValueTypes) + .generateRecordEqualiser("GroupAggValueEqualiser"); + + final OneInputStreamOperator<RowData, RowData> operator; + final boolean isMiniBatchEnabled = + tableConfig + .getConfiguration() + .getBoolean(ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_ENABLED); + if (isMiniBatchEnabled) { + MiniBatchGlobalGroupAggFunction aggFunction = + new MiniBatchGlobalGroupAggFunction( + localAggsHandler, + globalAggsHandler, + recordEqualiser, + globalAccTypes, + indexOfCountStar, + generateUpdateBefore, + tableConfig.getIdleStateRetention().toMillis()); + + operator = + new KeyedMapBundleOperator<>( + aggFunction, AggregateUtil.createMiniBatchTrigger(tableConfig)); + } else { + throw new TableException("Local-Global optimization is only worked in miniBatch mode"); + } + + // partitioned aggregation + final OneInputTransformation<RowData, RowData> transform = + new OneInputTransformation<>( + inputTransform, + getDesc(), + operator, + InternalTypeInfo.of(getOutputType()), + inputTransform.getParallelism()); + + if (inputsContainSingleton()) { + transform.setParallelism(1); + transform.setMaxParallelism(1); + } + + // set KeyType and Selector for state + final RowDataKeySelector selector = + KeySelectorUtil.getRowDataSelector(grouping, InternalTypeInfo.of(inputRowType)); + transform.setStateKeySelector(selector); + transform.setStateKeyType(selector.getProducedType()); + + return transform; + } + + private GeneratedAggsHandleFunction generateAggsHandler( + String name, + AggregateInfoList aggInfoList, + int mergedAccOffset, + DataType[] mergedAccExternalTypes, + TableConfig config, + RelBuilder relBuilder) { + + // For local aggregate, the result will be buffered, so copyInputField is true. + // For global aggregate, result will be put into state, then not need copy + // but this global aggregate result will be put into a buffered map first, + // then multi-put to state, so copyInputField is true. + AggsHandlerCodeGenerator generator = + new AggsHandlerCodeGenerator( + new CodeGeneratorContext(config), + relBuilder, + JavaScalaConversionUtil.toScala(localAggInputRowType.getChildren()), + true); + + return generator + .needMerge(mergedAccOffset, true, mergedAccExternalTypes) + .generateAggsHandler(name, aggInfoList); + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala index 9978208..5153643 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala @@ -472,7 +472,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { * @return interval of the given column on stream global group Aggregate */ def getColumnInterval( - aggregate: StreamExecGlobalGroupAggregate, + aggregate: StreamPhysicalGlobalGroupAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) @@ -537,7 +537,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { val groupSet = aggregate match { case agg: StreamPhysicalGroupAggregate => agg.grouping case agg: StreamPhysicalLocalGroupAggregate => agg.grouping - case agg: StreamExecGlobalGroupAggregate => agg.grouping + case agg: StreamPhysicalGlobalGroupAggregate => agg.grouping case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping case agg: StreamExecGroupWindowAggregate => agg.getGrouping case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ agg.getAuxGrouping @@ -597,10 +597,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { val aggCall = aggregate match { case agg: StreamPhysicalGroupAggregate if agg.aggCalls.length > aggCallIndex => agg.aggCalls(aggCallIndex) - case agg: StreamExecGlobalGroupAggregate - if agg.globalAggInfoList.getActualAggregateCalls.length > aggCallIndex => + case agg: StreamPhysicalGlobalGroupAggregate + if agg.aggCalls.length > aggCallIndex => val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg( - aggCallIndex, agg.globalAggInfoList.getActualAggregateCalls, agg.inputRowType) + aggCallIndex, agg.aggCalls, agg.localAggInputRowType) if (aggCallIndexInLocalAgg != null) { return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg) } else { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala index 13954d3..a111e79 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala @@ -349,7 +349,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata } def areColumnsUnique( - rel: StreamExecGlobalGroupAggregate, + rel: StreamPhysicalGlobalGroupAggregate, mq: RelMetadataQuery, columns: ImmutableBitSet, ignoreNulls: Boolean): JBoolean = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala index 5444b3b..f8c6646 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.FilteredColumnInterval import org.apache.flink.table.planner.plan.nodes.calcite.TableAggregate import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase -import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGroupAggregate, StreamPhysicalLocalGroupAggregate} +import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamPhysicalLocalGroupAggregate} import org.apache.flink.table.planner.plan.stats.ValueInterval import org.apache.flink.table.planner.plan.utils.ColumnIntervalUtil import org.apache.flink.util.Preconditions.checkArgument @@ -208,7 +208,7 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC } def getFilteredColumnInterval( - aggregate: StreamExecGlobalGroupAggregate, + aggregate: StreamPhysicalGlobalGroupAggregate, mq: RelMetadataQuery, columnIndex: Int, filterArg: Int): ValueInterval = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala index 1f2e1ee..272b557 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala @@ -289,7 +289,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } def getRelModifiedMonotonicity( - rel: StreamExecGlobalGroupAggregate, + rel: StreamPhysicalGlobalGroupAggregate, mq: RelMetadataQuery): RelModifiedMonotonicity = { // global and local agg should have same update monotonicity val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala index 96c5a9b..5962901 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala @@ -357,7 +357,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu ignoreNulls: Boolean): JSet[ImmutableBitSet] = null def getUniqueKeys( - rel: StreamExecGlobalGroupAggregate, + rel: StreamPhysicalGlobalGroupAggregate, mq: RelMetadataQuery, ignoreNulls: Boolean): JSet[ImmutableBitSet] = { getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala deleted file mode 100644 index 6c44b33..0000000 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.planner.plan.nodes.physical.stream - -import org.apache.flink.api.dag.Transformation -import org.apache.flink.streaming.api.transformations.OneInputTransformation -import org.apache.flink.table.api.config.ExecutionConfigOptions -import org.apache.flink.table.api.{TableConfig, TableException} -import org.apache.flink.table.data.RowData -import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, EqualiserCodeGenerator} -import org.apache.flink.table.planner.delegation.StreamPlanner -import org.apache.flink.table.planner.plan.PartialFinalType -import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode -import org.apache.flink.table.planner.plan.utils.{KeySelectorUtil, _} -import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction -import org.apache.flink.table.runtime.operators.aggregate.MiniBatchGlobalGroupAggFunction -import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator -import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo -import org.apache.flink.table.types.DataType - -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.{RelNode, RelWriter} -import org.apache.calcite.tools.RelBuilder - -import scala.collection.JavaConversions._ - -/** - * Stream physical RelNode for unbounded global group aggregate. - * - * @see [[StreamPhysicalGroupAggregateBase]] for more info. - */ -class StreamExecGlobalGroupAggregate( - cluster: RelOptCluster, - traitSet: RelTraitSet, - inputRel: RelNode, - val inputRowType: RelDataType, - outputRowType: RelDataType, - val grouping: Array[Int], - val localAggInfoList: AggregateInfoList, - val globalAggInfoList: AggregateInfoList, - val partialFinalType: PartialFinalType) - extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) - with LegacyStreamExecNode[RowData] { - - override def requireWatermark: Boolean = false - - override def deriveRowType(): RelDataType = outputRowType - - override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { - new StreamExecGlobalGroupAggregate( - cluster, - traitSet, - inputs.get(0), - inputRowType, - outputRowType, - grouping, - localAggInfoList, - globalAggInfoList, - partialFinalType) - } - - override def explainTerms(pw: RelWriter): RelWriter = { - super.explainTerms(pw) - .itemIf("groupBy", - RelExplainUtil.fieldToString(grouping, inputRel.getRowType), grouping.nonEmpty) - .itemIf("partialFinalType", partialFinalType, partialFinalType != PartialFinalType.NONE) - .item("select", RelExplainUtil.streamGroupAggregationToString( - inputRel.getRowType, - getRowType, - globalAggInfoList, - grouping, - isGlobal = true)) - } - - //~ ExecNode methods ----------------------------------------------------------- - - override protected def translateToPlanInternal( - planner: StreamPlanner): Transformation[RowData] = { - val tableConfig = planner.getTableConfig - - if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime < 0) { - LOG.warn("No state retention interval configured for a query which accumulates state. " + - "Please provide a query configuration with valid retention interval to prevent excessive " + - "state size. You may specify a retention time of 0 to not clean up the state.") - } - - val inputTransformation = getInputNodes.get(0).translateToPlan(planner) - .asInstanceOf[Transformation[RowData]] - - val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType) - - val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this) - - val localAggsHandler = generateAggsHandler( - "LocalGroupAggsHandler", - localAggInfoList, - mergedAccOffset = grouping.length, - mergedAccOnHeap = true, - localAggInfoList.getAccTypes, - tableConfig, - planner.getRelBuilder, - // the local aggregate result will be buffered, so need copy - inputFieldCopy = true) - - val globalAggsHandler = generateAggsHandler( - "GlobalGroupAggsHandler", - globalAggInfoList, - mergedAccOffset = 0, - mergedAccOnHeap = true, - localAggInfoList.getAccTypes, - tableConfig, - planner.getRelBuilder, - // if global aggregate result will be put into state, then not need copy - // but this global aggregate result will be put into a buffered map first, - // then multiput to state, so it need copy - inputFieldCopy = true) - - val indexOfCountStar = globalAggInfoList.getIndexOfCountStar - val globalAccTypes = globalAggInfoList.getAccTypes.map(fromDataTypeToLogicalType) - val globalAggValueTypes = globalAggInfoList - .getActualValueTypes - .map(fromDataTypeToLogicalType) - val recordEqualiser = new EqualiserCodeGenerator(globalAggValueTypes) - .generateRecordEqualiser("GroupAggValueEqualiser") - - val operator = if (tableConfig.getConfiguration.getBoolean( - ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_ENABLED)) { - val aggFunction = new MiniBatchGlobalGroupAggFunction( - localAggsHandler, - globalAggsHandler, - recordEqualiser, - globalAccTypes, - indexOfCountStar, - generateUpdateBefore, - tableConfig.getIdleStateRetention.toMillis) - - new KeyedMapBundleOperator( - aggFunction, - AggregateUtil.createMiniBatchTrigger(tableConfig)) - } else { - throw new TableException("Local-Global optimization is only worked in miniBatch mode") - } - - val inputTypeInfo = inputTransformation.getOutputType.asInstanceOf[InternalTypeInfo[RowData]] - val selector = KeySelectorUtil.getRowDataSelector(grouping, inputTypeInfo) - - // partitioned aggregation - val ret = new OneInputTransformation( - inputTransformation, - getRelDetailedDescription, - operator, - InternalTypeInfo.of(outRowType), - inputTransformation.getParallelism) - - if (inputsContainSingleton()) { - ret.setParallelism(1) - ret.setMaxParallelism(1) - } - - // set KeyType and Selector for state - ret.setStateKeySelector(selector) - ret.setStateKeyType(selector.getProducedType) - ret - } - - def generateAggsHandler( - name: String, - aggInfoList: AggregateInfoList, - mergedAccOffset: Int, - mergedAccOnHeap: Boolean, - mergedAccExternalTypes: Array[DataType], - config: TableConfig, - relBuilder: RelBuilder, - inputFieldCopy: Boolean): GeneratedAggsHandleFunction = { - - val generator = new AggsHandlerCodeGenerator( - CodeGeneratorContext(config), - relBuilder, - FlinkTypeFactory.toLogicalRowType(inputRowType).getChildren, - inputFieldCopy) - - generator - .needAccumulate() - .needMerge(mergedAccOffset, mergedAccOnHeap, mergedAccExternalTypes) - .generateAggsHandler(name, aggInfoList) - } - -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala index 556c4fc..bcb6a93 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala @@ -48,10 +48,10 @@ import scala.collection.JavaConversions._ * * <p>Considering the following sub-plan: * {{{ - * StreamExecGlobalGroupAggregate (final-global-aggregate) + * StreamPhysicalGlobalGroupAggregate (final-global-aggregate) * +- StreamPhysicalExchange * +- StreamPhysicalLocalGroupAggregate (final-local-aggregate) - * +- StreamExecGlobalGroupAggregate (partial-global-aggregate) + * +- StreamPhysicalGlobalGroupAggregate (partial-global-aggregate) * +- StreamPhysicalExchange * +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate) * }}} @@ -60,7 +60,7 @@ import scala.collection.JavaConversions._ * this node to share [[org.apache.flink.api.common.state.State]]. * now the sub-plan is * {{{ - * StreamExecGlobalGroupAggregate (final-global-aggregate) + * StreamPhysicalGlobalGroupAggregate (final-global-aggregate) * +- StreamPhysicalExchange * +- StreamExecIncrementalGroupAggregate * +- StreamPhysicalExchange diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala new file mode 100644 index 0000000..b4e0e2e --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.plan.nodes.physical.stream + +import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.plan.PartialFinalType +import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecGlobalGroupAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode} +import org.apache.flink.table.planner.plan.utils._ + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter} + +/** + * Stream physical RelNode for unbounded global group aggregate. + * + * @see [[StreamPhysicalGroupAggregateBase]] for more info. + */ +class StreamPhysicalGlobalGroupAggregate( + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputRel: RelNode, + outputRowType: RelDataType, + val grouping: Array[Int], + val aggCalls: Seq[AggregateCall], + aggCallNeedRetractions: Array[Boolean], + val localAggInputRowType: RelDataType, + needRetraction: Boolean, + val partialFinalType: PartialFinalType) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) { + + lazy val localAggInfoList: AggregateInfoList = AggregateUtil.transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(localAggInputRowType), + aggCalls, + aggCallNeedRetractions, + needRetraction, + isStateBackendDataViews = false) + + lazy val globalAggInfoList: AggregateInfoList = AggregateUtil.transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(localAggInputRowType), + aggCalls, + aggCallNeedRetractions, + needRetraction, + isStateBackendDataViews = true) + + override def requireWatermark: Boolean = false + + override def deriveRowType(): RelDataType = outputRowType + + override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { + new StreamPhysicalGlobalGroupAggregate( + cluster, + traitSet, + inputs.get(0), + outputRowType, + grouping, + aggCalls, + aggCallNeedRetractions, + localAggInputRowType, + needRetraction, + partialFinalType) + } + + override def explainTerms(pw: RelWriter): RelWriter = { + super.explainTerms(pw) + .itemIf("groupBy", + RelExplainUtil.fieldToString(grouping, inputRel.getRowType), grouping.nonEmpty) + .itemIf("partialFinalType", partialFinalType, partialFinalType != PartialFinalType.NONE) + .item("select", RelExplainUtil.streamGroupAggregationToString( + inputRel.getRowType, + getRowType, + globalAggInfoList, + grouping, + isGlobal = true)) + } + + override def translateToExecNode(): ExecNode[_] = { + val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this) + new StreamExecGlobalGroupAggregate( + grouping, + aggCalls.toArray, + aggCallNeedRetractions, + FlinkTypeFactory.toLogicalRowType(localAggInputRowType), + generateUpdateBefore, + needRetraction, + ExecEdge.DEFAULT, + FlinkTypeFactory.toLogicalRowType(getRowType), + getRelDetailedDescription + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala index ce8b563..ee8f074 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala @@ -22,7 +22,7 @@ import org.apache.flink.configuration.ConfigOption import org.apache.flink.configuration.ConfigOptions.key import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory} import org.apache.flink.table.planner.plan.PartialFinalType -import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecIncrementalGroupAggregate, StreamPhysicalLocalGroupAggregate, StreamPhysicalExchange} +import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecIncrementalGroupAggregate, StreamPhysicalExchange, StreamPhysicalGlobalGroupAggregate, StreamPhysicalLocalGroupAggregate} import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil, DistinctInfo} import org.apache.flink.util.Preconditions @@ -33,23 +33,24 @@ import java.lang.{Boolean => JBoolean} import java.util.Collections /** - * Rule that matches final [[StreamExecGlobalGroupAggregate]] on [[StreamPhysicalExchange]] - * on final [[StreamPhysicalLocalGroupAggregate]] on partial [[StreamExecGlobalGroupAggregate]], - * and combines the final [[StreamPhysicalLocalGroupAggregate]] and - * the partial [[StreamExecGlobalGroupAggregate]] into a [[StreamExecIncrementalGroupAggregate]]. - */ + * Rule that matches final [[StreamPhysicalGlobalGroupAggregate]] on [[StreamPhysicalExchange]] + * on final [[StreamPhysicalLocalGroupAggregate]] on partial [[StreamPhysicalGlobalGroupAggregate]], + * and combines the final [[StreamPhysicalLocalGroupAggregate]] and + * the partial [[StreamPhysicalGlobalGroupAggregate]] into a + * [[StreamExecIncrementalGroupAggregate]]. + */ class IncrementalAggregateRule extends RelOptRule( - operand(classOf[StreamExecGlobalGroupAggregate], // final global agg + operand(classOf[StreamPhysicalGlobalGroupAggregate], // final global agg operand(classOf[StreamPhysicalExchange], // key by operand(classOf[StreamPhysicalLocalGroupAggregate], // final local agg - operand(classOf[StreamExecGlobalGroupAggregate], any())))), // partial global agg + operand(classOf[StreamPhysicalGlobalGroupAggregate], any())))), // partial global agg "IncrementalAggregateRule") { override def matches(call: RelOptRuleCall): Boolean = { - val finalGlobalAgg: StreamExecGlobalGroupAggregate = call.rel(0) + val finalGlobalAgg: StreamPhysicalGlobalGroupAggregate = call.rel(0) val finalLocalAgg: StreamPhysicalLocalGroupAggregate = call.rel(2) - val partialGlobalAgg: StreamExecGlobalGroupAggregate = call.rel(3) + val partialGlobalAgg: StreamPhysicalGlobalGroupAggregate = call.rel(3) val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig @@ -64,11 +65,11 @@ class IncrementalAggregateRule } override def onMatch(call: RelOptRuleCall): Unit = { - val finalGlobalAgg: StreamExecGlobalGroupAggregate = call.rel(0) + val finalGlobalAgg: StreamPhysicalGlobalGroupAggregate = call.rel(0) val exchange: StreamPhysicalExchange = call.rel(1) val finalLocalAgg: StreamPhysicalLocalGroupAggregate = call.rel(2) - val partialGlobalAgg: StreamExecGlobalGroupAggregate = call.rel(3) - val aggInputRowType = partialGlobalAgg.inputRowType + val partialGlobalAgg: StreamPhysicalGlobalGroupAggregate = call.rel(3) + val aggInputRowType = partialGlobalAgg.localAggInputRowType val partialLocalAggInfoList = partialGlobalAgg.localAggInfoList val partialGlobalAggInfoList = partialGlobalAgg.globalAggInfoList @@ -141,16 +142,6 @@ class IncrementalAggregateRule needInputCount = false, // the local agg is not works on state isStateBackendDataViews = false) - val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( - // the final agg input is partial agg - FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType), - aggCalls, - // all the aggs do not need retraction - Array.fill(aggCalls.length)(false), - // also do not need count* - needInputCount = false, - // the global agg is works on state - isStateBackendDataViews = true) // check whether the global agg required input row type equals the incr agg output row type val globalAggInputAccType = AggregateUtil.inferLocalAggRowType( @@ -164,15 +155,17 @@ class IncrementalAggregateRule globalAggInputAccType, false)) - new StreamExecGlobalGroupAggregate( + new StreamPhysicalGlobalGroupAggregate( finalGlobalAgg.getCluster, finalGlobalAgg.getTraitSet, newExchange, - finalGlobalAgg.inputRowType, finalGlobalAgg.getRowType, finalGlobalAgg.grouping, - localAggInfoList, // the agg info list is changed - globalAggInfoList, // the agg info list is changed + aggCalls, + // all the aggs do not need retraction + Array.fill(aggCalls.length)(false), + finalGlobalAgg.localAggInputRowType, + needRetraction = false, finalGlobalAgg.partialFinalType) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala index 2f804cc..1399bcf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala @@ -25,14 +25,13 @@ import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.physical.stream._ import org.apache.flink.table.planner.plan.rules.physical.FlinkExpandConversionRule._ -import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil, ChangelogPlanUtils} +import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ChangelogPlanUtils} import org.apache.flink.table.planner.utils.AggregatePhaseStrategy import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.core.AggregateCall import java.util @@ -46,7 +45,7 @@ import java.util * * and converts them to * {{{ - * StreamExecGlobalGroupAggregate + * StreamPhysicalGlobalGroupAggregate * +- StreamPhysicalExchange * +- StreamPhysicalLocalGroupAggregate * +- input of exchange @@ -93,76 +92,50 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( } override def onMatch(call: RelOptRuleCall): Unit = { - val agg: StreamPhysicalGroupAggregate = call.rel(0) + val originalAgg: StreamPhysicalGroupAggregate = call.rel(0) val realInput: RelNode = call.rel(2) val needRetraction = !ChangelogPlanUtils.isInsertOnly( realInput.asInstanceOf[StreamPhysicalRel]) val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) - val monotonicity = fmq.getRelModifiedMonotonicity(agg) + val monotonicity = fmq.getRelModifiedMonotonicity(originalAgg) val aggCallNeedRetractions = AggregateUtil.deriveAggCallNeedRetractions( - agg.grouping.length, agg.aggCalls, needRetraction, monotonicity) - - val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(realInput.getRowType), - agg.aggCalls, - aggCallNeedRetractions, - needRetraction, - isStateBackendDataViews = true) - - val globalHashAgg = createTwoStageAgg( - realInput, agg.aggCalls, aggCallNeedRetractions, needRetraction, globalAggInfoList, agg) - call.transformTo(globalHashAgg) - } - - // the difference between localAggInfos and aggInfos is local agg use heap dataview, - // but global agg use state dataview - private def createTwoStageAgg( - realInput: RelNode, - aggCalls: Seq[AggregateCall], - aggCallNeedRetractions: Array[Boolean], - needRetraction: Boolean, - globalAggInfoList: AggregateInfoList, - agg: StreamPhysicalGroupAggregate): StreamExecGlobalGroupAggregate = { + originalAgg.grouping.length, originalAgg.aggCalls, needRetraction, monotonicity) // local agg shouldn't produce insert only messages val localAggTraitSet = realInput.getTraitSet .plus(ModifyKindSetTrait.INSERT_ONLY) .plus(UpdateKindTrait.NONE) val localHashAgg = new StreamPhysicalLocalGroupAggregate( - agg.getCluster, + originalAgg.getCluster, localAggTraitSet, realInput, - agg.grouping, - agg.aggCalls, + originalAgg.grouping, + originalAgg.aggCalls, aggCallNeedRetractions, needRetraction, - agg.partialFinalType) + originalAgg.partialFinalType) // grouping keys is forwarded by local agg, use indices instead of groupings - val globalGrouping = agg.grouping.indices.toArray + val globalGrouping = originalAgg.grouping.indices.toArray val globalDistribution = createDistribution(globalGrouping) // create exchange if needed val newInput = satisfyDistribution( FlinkConventions.STREAM_PHYSICAL, localHashAgg, globalDistribution) - val globalAggProvidedTraitSet = agg.getTraitSet + val globalAggProvidedTraitSet = originalAgg.getTraitSet - // TODO Temporary solution, remove it later - val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(realInput.getRowType), - aggCalls, - aggCallNeedRetractions, - needRetraction, - isStateBackendDataViews = false) - new StreamExecGlobalGroupAggregate( - agg.getCluster, + val globalAgg = new StreamPhysicalGlobalGroupAggregate( + originalAgg.getCluster, globalAggProvidedTraitSet, newInput, - realInput.getRowType, - agg.getRowType, + originalAgg.getRowType, globalGrouping, - localAggInfoList, - globalAggInfoList, - agg.partialFinalType) + originalAgg.aggCalls, + aggCallNeedRetractions, + realInput.getRowType, + needRetraction, + originalAgg.partialFinalType) + + call.transformTo(globalAgg) } private def createDistribution(keys: Array[Int]): FlinkRelDistribution = { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala index ec2d55c..aaa5f68 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala @@ -44,7 +44,6 @@ import org.apache.flink.table.planner.plan.nodes.physical.batch._ import org.apache.flink.table.planner.plan.nodes.physical.stream._ import org.apache.flink.table.planner.plan.schema.FlinkPreparingTableBase import org.apache.flink.table.planner.plan.stream.sql.join.TestTemporalTable -import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToStreamAggregateInfoList import org.apache.flink.table.planner.plan.utils._ import org.apache.flink.table.planner.utils.{CountAggFunction, Top3} import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankType, VariableRankRange} @@ -1033,28 +1032,16 @@ class FlinkRelMdHandlerTestBase { val streamExchange1 = new StreamPhysicalExchange( cluster, streamLocalAgg.getTraitSet.replace(hash0), streamLocalAgg, hash0) - val globalAggInfoList = transformToStreamAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(streamExchange1.getRowType), - aggCalls, - aggCallNeedRetractions, - needInputCount = false, - isStateBackendDataViews = true) - // TODO Temporary solution, remove it later - val localAggInfoList = transformToStreamAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(studentStreamScan.getRowType), - aggCalls, - aggCallNeedRetractions, - needInputCount = false, - isStateBackendDataViews = false) - val streamGlobalAgg = new StreamExecGlobalGroupAggregate( + val streamGlobalAgg = new StreamPhysicalGlobalGroupAggregate( cluster, streamPhysicalTraits, streamExchange1, - streamExchange1.getRowType, rowTypeOfGlobalAgg, Array(0), - localAggInfoList, - globalAggInfoList, + aggCalls, + aggCallNeedRetractions, + streamLocalAgg.getInput.getRowType, + AggregateUtil.needRetraction(streamLocalAgg), PartialFinalType.NONE) val streamExchange2 = new StreamPhysicalExchange(cluster,
