This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 53a888c52499fdd452e41f5e21dc7b4a9630e4f5 Author: godfreyhe <[email protected]> AuthorDate: Thu Dec 24 18:12:47 2020 +0800 [FLINK-20737][table-planner-blink] Use RowType instead of RelDataType when building aggregate info This closes #14478 --- .../batch/BatchExecPythonAggregateRule.java | 5 +- .../batch/BatchExecPythonWindowAggregateRule.java | 5 +- .../table/planner/codegen/MatchCodeGenerator.scala | 2 +- .../batch/BatchExecHashAggregateBase.scala | 2 +- .../batch/BatchExecHashWindowAggregateBase.scala | 2 +- .../physical/batch/BatchExecOverAggregate.scala | 18 ++-- .../batch/BatchExecSortAggregateBase.scala | 2 +- .../batch/BatchExecSortWindowAggregateBase.scala | 2 +- .../physical/stream/StreamExecGroupAggregate.scala | 4 +- .../stream/StreamExecGroupTableAggregateBase.scala | 4 +- .../StreamExecGroupWindowAggregateBase.scala | 2 +- .../physical/stream/StreamExecOverAggregate.scala | 8 +- .../stream/StreamExecPythonGroupAggregate.scala | 4 +- .../physical/batch/BatchExecAggRuleBase.scala | 2 +- .../physical/batch/BatchExecHashAggRule.scala | 4 +- .../batch/BatchExecOverAggregateRule.scala | 4 +- .../physical/batch/BatchExecSortAggRule.scala | 4 +- .../batch/BatchExecWindowAggregateRule.scala | 2 +- .../physical/batch/EnforceLocalAggRuleBase.scala | 5 +- .../physical/stream/IncrementalAggregateRule.scala | 8 +- .../stream/TwoStageOptimizedAggregateRule.scala | 10 +- .../planner/plan/utils/AggFunctionFactory.scala | 46 ++++----- .../table/planner/plan/utils/AggregateUtil.scala | 105 ++++++++++----------- .../plan/metadata/FlinkRelMdHandlerTestBase.scala | 27 ++++-- 24 files changed, 140 insertions(+), 137 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonAggregateRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonAggregateRule.java index 9b69cc2..f1c2ab5 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonAggregateRule.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonAggregateRule.java @@ -21,6 +21,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch; import org.apache.flink.table.api.TableException; import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.functions.python.PythonFunctionKind; +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate; @@ -106,7 +107,9 @@ public class BatchExecPythonAggregateRule extends ConverterRule { Tuple3<int[][], DataType[][], UserDefinedFunction[]> aggBufferTypesAndFunctions = AggregateUtil.transformToBatchAggregateFunctions( - aggCallsWithoutAuxGroupCalls, input.getRowType(), null); + FlinkTypeFactory.toLogicalRowType(input.getRowType()), + aggCallsWithoutAuxGroupCalls, + null); UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3(); RelTraitSet requiredTraitSet = diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonWindowAggregateRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonWindowAggregateRule.java index 154dc07..958188d 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonWindowAggregateRule.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonWindowAggregateRule.java @@ -22,6 +22,7 @@ import org.apache.flink.table.api.TableException; import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.functions.python.PythonFunctionKind; import org.apache.flink.table.planner.calcite.FlinkRelFactories; +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.plan.logical.LogicalWindow; import org.apache.flink.table.planner.plan.logical.SessionGroupWindow; import org.apache.flink.table.planner.plan.logical.SlidingGroupWindow; @@ -121,7 +122,9 @@ public class BatchExecPythonWindowAggregateRule extends RelOptRule { Tuple3<int[][], DataType[][], UserDefinedFunction[]> aggBufferTypesAndFunctions = AggregateUtil.transformToBatchAggregateFunctions( - aggCallsWithoutAuxGroupCalls, input.getRowType(), null); + FlinkTypeFactory.toLogicalRowType(input.getRowType()), + aggCallsWithoutAuxGroupCalls, + null); UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3(); int inputTimeFieldIndex = diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/MatchCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/MatchCodeGenerator.scala index b4f83fd..4297a29 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/MatchCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/MatchCodeGenerator.scala @@ -682,8 +682,8 @@ class MatchCodeGenerator( matchAgg.inputExprs.indices.map(i => s"TMP$i")) val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(inputRelType), aggCalls, - inputRelType, needRetraction, needInputCount = false, isStateBackendDataViews = false, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregateBase.scala index 10b1ed7..9400dbd 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregateBase.scala @@ -114,7 +114,7 @@ abstract class BatchExecHashAggregateBase( val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val aggInfos = transformToBatchAggregateInfoList( - aggCallToAggFunction.map(_._1), aggInputRowType) + FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1)) var managedMemory: Long = 0L val generatedOperator = if (grouping.isEmpty) { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashWindowAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashWindowAggregateBase.scala index 595c0b8..c7e1f12 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashWindowAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashWindowAggregateBase.scala @@ -113,7 +113,7 @@ abstract class BatchExecHashWindowAggregateBase( val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val aggInfos = transformToBatchAggregateInfoList( - aggCallToAggFunction.map(_._1), aggInputRowType) + FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1)) val groupBufferLimitSize = config.getConfiguration.getInteger( ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecOverAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecOverAggregate.scala index 82329f5..04e1c3f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecOverAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecOverAggregate.scala @@ -140,10 +140,10 @@ class BatchExecOverAggregate( //operator needn't cache data val aggHandlers = modeToGroupToAggCallToAggFunction.map { case (_, _, aggCallToAggFunction) => val aggInfoList = transformToBatchAggregateInfoList( - aggCallToAggFunction.map(_._1), // use aggInputType which considers constants as input instead of inputType - inputTypeWithConstants, - orderKeyIndices) + FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants), + aggCallToAggFunction.map(_._1), + orderKeyIndexes = orderKeyIndices) val codeGenCtx = CodeGeneratorContext(config) val generator = new AggsHandlerCodeGenerator( codeGenCtx, @@ -191,10 +191,10 @@ class BatchExecOverAggregate( //lies on the offset of the window frame. aggCallToAggFunction.map { case (aggCall, _) => val aggInfoList = transformToBatchAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants), Seq(aggCall), - inputTypeWithConstants, - orderKeyIndices, - Array[Boolean](true) /* needRetraction = true, See LeadLagAggFunction */) + Array[Boolean](true), /* needRetraction = true, See LeadLagAggFunction */ + orderKeyIndexes = orderKeyIndices) val generator = new AggsHandlerCodeGenerator( CodeGeneratorContext(config), @@ -263,10 +263,10 @@ class BatchExecOverAggregate( case _ => val aggInfoList = transformToBatchAggregateInfoList( - aggCallToAggFunction.map(_._1), //use aggInputType which considers constants as input instead of inputSchema.relDataType - inputTypeWithConstants, - orderKeyIndices) + FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants), + aggCallToAggFunction.map(_._1), + orderKeyIndexes = orderKeyIndices) val codeGenCtx = CodeGeneratorContext(config) val generator = new AggsHandlerCodeGenerator( codeGenCtx, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala index bf2dec4..12a617f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala @@ -95,7 +95,7 @@ abstract class BatchExecSortAggregateBase( val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val aggInfos = transformToBatchAggregateInfoList( - aggCallToAggFunction.map(_._1), aggInputRowType) + FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1)) val generatedOperator = if (grouping.isEmpty) { AggWithoutKeysCodeGenerator.genWithoutKeys( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortWindowAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortWindowAggregateBase.scala index f919309..64c1b0e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortWindowAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortWindowAggregateBase.scala @@ -101,7 +101,7 @@ abstract class BatchExecSortWindowAggregateBase( val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val aggInfos = transformToBatchAggregateInfoList( - aggCallToAggFunction.map(_._1), aggInputRowType) + FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1)) val groupBufferLimitSize = planner.getTableConfig.getConfiguration.getInteger( ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala index 210ee48..f7ae1d2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala @@ -63,8 +63,8 @@ class StreamExecGroupAggregate( val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList( this, - aggCalls, - grouping) + grouping.length, + aggCalls) override def requireWatermark: Boolean = false diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregateBase.scala index 4c02b7d..c7c8a9b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupTableAggregateBase.scala @@ -42,8 +42,8 @@ abstract class StreamExecGroupTableAggregateBase( val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList( this, - aggCalls, - grouping) + grouping.length, + aggCalls) override def requireWatermark: Boolean = false diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupWindowAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupWindowAggregateBase.scala index b8c0aa2..47b7b08 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupWindowAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupWindowAggregateBase.scala @@ -137,8 +137,8 @@ abstract class StreamExecGroupWindowAggregateBase( val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this) val aggInfoList = transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(inputRowType), aggCalls, - inputRowType, Array.fill(aggCalls.size)(needRetraction), needInputCount = needRetraction, isStateBackendDataViews = true) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecOverAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecOverAggregate.scala index fecd78e..494225f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecOverAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecOverAggregate.scala @@ -243,9 +243,9 @@ class StreamExecOverAggregate( val needRetraction = false val aggInfoList = transformToStreamAggregateInfoList( - aggregateCalls, // use aggInputType which considers constants as input instead of inputSchema.relDataType - aggInputType, + FlinkTypeFactory.toLogicalRowType(aggInputType), + aggregateCalls, Array.fill(aggregateCalls.size)(needRetraction), needInputCount = needRetraction, isStateBackendDataViews = true) @@ -322,9 +322,9 @@ class StreamExecOverAggregate( val needRetraction = true val aggInfoList = transformToStreamAggregateInfoList( - aggregateCalls, // use aggInputType which considers constants as input instead of inputSchema.relDataType - aggInputType, + FlinkTypeFactory.toLogicalRowType(aggInputType), + aggregateCalls, Array.fill(aggregateCalls.size)(needRetraction), needInputCount = needRetraction, isStateBackendDataViews = true) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala index 4618183..aabf675 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala @@ -59,8 +59,8 @@ class StreamExecPythonGroupAggregate( val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList( this, - aggCalls, - grouping) + grouping.length, + aggCalls) override def requireWatermark: Boolean = false diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecAggRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecAggRuleBase.scala index 49747b4..7d3a190 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecAggRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecAggRuleBase.scala @@ -154,7 +154,7 @@ trait BatchExecAggRuleBase { protected def isAggBufferFixedLength(agg: Aggregate): Boolean = { val (_, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (_, aggBufferTypes, _) = AggregateUtil.transformToBatchAggregateFunctions( - aggCallsWithoutAuxGroupCalls, agg.getInput.getRowType) + FlinkTypeFactory.toLogicalRowType(agg.getInput.getRowType), aggCallsWithoutAuxGroupCalls) isAggBufferFixedLength(aggBufferTypes.map(_.map(fromDataTypeToLogicalType))) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala index 76b3ef5..1a28774 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch import org.apache.flink.table.api.config.OptimizerConfigOptions -import org.apache.flink.table.planner.calcite.FlinkContext +import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory} import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate @@ -87,7 +87,7 @@ class BatchExecHashAggRule val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions( - aggCallsWithoutAuxGroupCalls, inputRowType) + FlinkTypeFactory.toLogicalRowType(inputRowType), aggCallsWithoutAuxGroupCalls) val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions) val aggProvidedTraitSet = agg.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecOverAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecOverAggregateRule.scala index 58d1c45..091dbb2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecOverAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecOverAggregateRule.scala @@ -94,7 +94,9 @@ class BatchExecOverAggregateRule val groupToAggCallToAggFunction = groupBuffer.map { group => val aggregateCalls = group.getAggregateCalls(logicWindow) val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions( - aggregateCalls, inputTypeWithConstants, orderKeyIndexes) + FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants), + aggregateCalls, + orderKeyIndexes) val aggCallToAggFunction = aggregateCalls.zip(aggregates) (group, aggCallToAggFunction) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala index c164228..7bf6d16 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch import org.apache.flink.table.api.config.OptimizerConfigOptions -import org.apache.flink.table.planner.calcite.FlinkContext +import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory} import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate @@ -80,7 +80,7 @@ class BatchExecSortAggRule val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions( - aggCallsWithoutAuxGroupCalls, inputRowType) + FlinkTypeFactory.toLogicalRowType(inputRowType), aggCallsWithoutAuxGroupCalls) val groupSet = agg.getGroupSet.toArray val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions) // TODO aggregate include projection now, so do not provide new trait will be safe diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala index e179187..81151f2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala @@ -101,7 +101,7 @@ class BatchExecWindowAggregateRule val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (_, aggBufferTypes, aggregates) = AggregateUtil.transformToBatchAggregateFunctions( - aggCallsWithoutAuxGroupCalls, input.getRowType) + FlinkTypeFactory.toLogicalRowType(input.getRowType), aggCallsWithoutAuxGroupCalls) val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggregates) val internalAggBufferTypes = aggBufferTypes.map(_.map(fromDataTypeToLogicalType)) val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala index 1ff432e..34a4de2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala @@ -19,9 +19,10 @@ package org.apache.flink.table.planner.plan.rules.physical.batch import org.apache.flink.table.api.TableException +import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalExpand, BatchExecGroupAggregateBase, BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand} @@ -75,7 +76,7 @@ abstract class EnforceLocalAggRuleBase( val aggCallToAggFunction = completeAgg.getAggCallToAggFunction val (_, aggBufferTypes, _) = AggregateUtil.transformToBatchAggregateFunctions( - aggCalls, inputRowType) + FlinkTypeFactory.toLogicalRowType(inputRowType), aggCalls) val traitSet = cluster.getPlanner .emptyTraitSet diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala index a9e8628..a0bb8fd 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala @@ -132,9 +132,9 @@ class IncrementalAggregateRule } else { // an additional count1 is inserted, need to adapt the global agg val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( - aggCalls, // the final agg input is partial agg - partialGlobalAgg.getRowType, + FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType), + aggCalls, // all the aggs do not need retraction Array.fill(aggCalls.length)(false), // also do not need count* @@ -142,9 +142,9 @@ class IncrementalAggregateRule // the local agg is not works on state isStateBackendDataViews = false) val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( - aggCalls, // the final agg input is partial agg - partialGlobalAgg.getRowType, + FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType), + aggCalls, // all the aggs do not need retraction Array.fill(aggCalls.length)(false), // also do not need count* diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala index 11a6281..afb24fb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala @@ -66,11 +66,11 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) val monotonicity = fmq.getRelModifiedMonotonicity(agg) val needRetractionArray = AggregateUtil.getNeedRetractions( - agg.grouping.length, needRetraction, monotonicity, agg.aggCalls) + agg.grouping.length, agg.aggCalls, needRetraction, monotonicity) val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType( agg.getInput.getRowType), agg.aggCalls, - agg.getInput.getRowType, needRetractionArray, needRetraction, isStateBackendDataViews = true) @@ -98,18 +98,18 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) val monotonicity = fmq.getRelModifiedMonotonicity(agg) val needRetractionArray = AggregateUtil.getNeedRetractions( - agg.grouping.length, needRetraction, monotonicity, agg.aggCalls) + agg.grouping.length, agg.aggCalls, needRetraction, monotonicity) val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(realInput.getRowType), agg.aggCalls, - realInput.getRowType, needRetractionArray, needRetraction, isStateBackendDataViews = false) val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(realInput.getRowType), agg.aggCalls, - realInput.getRowType, needRetractionArray, needRetraction, isStateBackendDataViews = true) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index 798e256..bfe82e6 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -19,24 +19,17 @@ package org.apache.flink.table.planner.plan.utils import org.apache.flink.table.api.TableException import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.functions.aggfunctions.FirstValueAggFunction._ -import org.apache.flink.table.planner.functions.aggfunctions.FirstValueWithRetractAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions.IncrSumAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions.IncrSumWithRetractAggFunction._ -import org.apache.flink.table.planner.functions.aggfunctions.LastValueAggFunction._ -import org.apache.flink.table.planner.functions.aggfunctions.LastValueWithRetractAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions.SingleValueAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions._ import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction import org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction} import org.apache.flink.table.planner.functions.utils.AggSqlFunction -import org.apache.flink.table.runtime.typeutils.DecimalDataTypeInfo import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical._ -import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.sql.fun._ import org.apache.calcite.sql.{SqlAggFunction, SqlKind, SqlRankFunction} @@ -49,14 +42,14 @@ import scala.collection.JavaConversions._ * The class of agg function factory which is used to create AggregateFunction or * DeclarativeAggregateFunction from Calcite AggregateCall * - * @param inputType the input rel data type - * @param orderKeyIdx the indexes of order key (null when is not over agg) - * @param needRetraction true if need retraction + * @param inputRowType the input's output RowType + * @param orderKeyIndexes the indexes of order key (null when is not over agg) + * @param aggCallNeedRetractions true if need retraction */ class AggFunctionFactory( - inputType: RelDataType, - orderKeyIdx: Array[Int], - needRetraction: Array[Boolean]) { + inputRowType: RowType, + orderKeyIndexes: Array[Int], + aggCallNeedRetractions: Array[Boolean]) { /** * The entry point to create an aggregate function from the given AggregateCall @@ -64,8 +57,7 @@ class AggFunctionFactory( def createAggFunction(call: AggregateCall, index: Int): UserDefinedFunction = { val argTypes: Array[LogicalType] = call.getArgList - .map(inputType.getFieldList.get(_).getType) - .map(FlinkTypeFactory.toLogicalType) + .map(inputRowType.getChildren.get(_)) .toArray call.getAggregation match { @@ -165,7 +157,7 @@ class AggFunctionFactory( private def createSumAggFunction( argTypes: Array[LogicalType], index: Int): UserDefinedFunction = { - if (needRetraction(index)) { + if (aggCallNeedRetractions(index)) { argTypes(0).getTypeRoot match { case TINYINT => new ByteSumWithRetractAggFunction @@ -236,7 +228,7 @@ class AggFunctionFactory( private def createIncrSumAggFunction( argTypes: Array[LogicalType], index: Int): UserDefinedFunction = { - if (needRetraction(index)) { + if (aggCallNeedRetractions(index)) { argTypes(0).getTypeRoot match { case TINYINT => new ByteIncrSumWithRetractAggFunction @@ -286,7 +278,7 @@ class AggFunctionFactory( index: Int) : UserDefinedFunction = { val valueType = argTypes(0) - if (needRetraction(index)) { + if (aggCallNeedRetractions(index)) { valueType.getTypeRoot match { case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL | TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE => @@ -370,7 +362,7 @@ class AggFunctionFactory( index: Int) : UserDefinedFunction = { val valueType = argTypes(0) - if (needRetraction(index)) { + if (aggCallNeedRetractions(index)) { valueType.getTypeRoot match { case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL | TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE => @@ -460,16 +452,12 @@ class AggFunctionFactory( } private def createRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { - val argTypes = orderKeyIdx - .map(inputType.getFieldList.get(_).getType) - .map(FlinkTypeFactory.toLogicalType) + val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_)) new RankAggFunction(argTypes) } private def createDenseRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { - val argTypes = orderKeyIdx - .map(inputType.getFieldList.get(_).getType) - .map(FlinkTypeFactory.toLogicalType) + val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_)) new DenseRankAggFunction(argTypes) } @@ -478,7 +466,7 @@ class AggFunctionFactory( index: Int) : UserDefinedFunction = { val valueType = argTypes(0) - if (needRetraction(index)) { + if (aggCallNeedRetractions(index)) { valueType.getTypeRoot match { case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL => new FirstValueWithRetractAggFunction(valueType) @@ -502,7 +490,7 @@ class AggFunctionFactory( index: Int) : UserDefinedFunction = { val valueType = argTypes(0) - if (needRetraction(index)) { + if (aggCallNeedRetractions(index)) { valueType.getTypeRoot match { case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL => new LastValueWithRetractAggFunction(valueType) @@ -524,7 +512,7 @@ class AggFunctionFactory( private def createListAggFunction( argTypes: Array[LogicalType], index: Int): UserDefinedFunction = { - if (needRetraction(index)) { + if (aggCallNeedRetractions(index)) { new ListAggWithRetractAggFunction } else { new ListAggFunction(1) @@ -534,7 +522,7 @@ class AggFunctionFactory( private def createListAggWsFunction( argTypes: Array[LogicalType], index: Int): UserDefinedFunction = { - if (needRetraction(index)) { + if (aggCallNeedRetractions(index)) { new ListAggWsWithRetractAggFunction } else { new ListAggFunction(2) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index d3bf852..01fd4d4 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -151,12 +151,12 @@ object AggregateUtil extends Enumeration { def getOutputIndexToAggCallIndexMap( aggregateCalls: Seq[AggregateCall], inputType: RelDataType, - orderKeyIdx: Array[Int] = null): util.Map[Integer, Integer] = { + orderKeyIndexes: Array[Int] = null): util.Map[Integer, Integer] = { val aggInfos = transformToAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(inputType), aggregateCalls, - inputType, - orderKeyIdx, Array.fill(aggregateCalls.size)(false), + orderKeyIndexes, needInputCount = false, isStateBackedDataViews = false, needDistinctInfo = false).aggInfos @@ -176,10 +176,10 @@ object AggregateUtil extends Enumeration { } def deriveAggregateInfoList( - aggNode: StreamPhysicalRel, - aggCalls: Seq[AggregateCall], - grouping: Array[Int]): AggregateInfoList = { - val input = aggNode.getInput(0) + agg: StreamPhysicalRel, + groupCount: Int, + aggCalls: Seq[AggregateCall]): AggregateInfoList = { + val input = agg.getInput(0) // need to call `retract()` if input contains update or delete val modifyKindSetTrait = input.getTraitSet.getTrait(ModifyKindSetTraitDef.INSTANCE) val needRetraction = if (modifyKindSetTrait == null) { @@ -188,29 +188,28 @@ object AggregateUtil extends Enumeration { } else { !modifyKindSetTrait.modifyKindSet.isInsertOnly } - val fmq = FlinkRelMetadataQuery.reuseOrCreate(aggNode.getCluster.getMetadataQuery) - val monotonicity = fmq.getRelModifiedMonotonicity(aggNode) - val needRetractionArray = AggregateUtil.getNeedRetractions( - grouping.length, needRetraction, monotonicity, aggCalls) - AggregateUtil.transformToStreamAggregateInfoList( + val fmq = FlinkRelMetadataQuery.reuseOrCreate(agg.getCluster.getMetadataQuery) + val monotonicity = fmq.getRelModifiedMonotonicity(agg) + val needRetractionArray = getNeedRetractions(groupCount, aggCalls, needRetraction, monotonicity) + transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(input.getRowType), aggCalls, - input.getRowType, needRetractionArray, needInputCount = needRetraction, isStateBackendDataViews = true) } def transformToBatchAggregateFunctions( + inputRowType: RowType, aggregateCalls: Seq[AggregateCall], - inputRowType: RelDataType, - orderKeyIdx: Array[Int] = null) + orderKeyIndexes: Array[Int] = null) : (Array[Array[Int]], Array[Array[DataType]], Array[UserDefinedFunction]) = { val aggInfos = transformToAggregateInfoList( - aggregateCalls, inputRowType, - orderKeyIdx, + aggregateCalls, Array.fill(aggregateCalls.size)(false), + orderKeyIndexes, needInputCount = false, isStateBackedDataViews = false, needDistinctInfo = false).aggInfos @@ -223,39 +222,39 @@ object AggregateUtil extends Enumeration { } def transformToBatchAggregateInfoList( - aggregateCalls: Seq[AggregateCall], - inputRowType: RelDataType, - orderKeyIdx: Array[Int] = null, - needRetractions: Array[Boolean] = null): AggregateInfoList = { + inputRowType: RowType, + aggCalls: Seq[AggregateCall], + aggCallNeedRetractions: Array[Boolean] = null, + orderKeyIndexes: Array[Int] = null): AggregateInfoList = { - val needRetractionArray = if (needRetractions == null) { - Array.fill(aggregateCalls.size)(false) + val finalAggCallNeedRetractions = if (aggCallNeedRetractions == null) { + Array.fill(aggCalls.size)(false) } else { - needRetractions + aggCallNeedRetractions } transformToAggregateInfoList( - aggregateCalls, inputRowType, - orderKeyIdx, - needRetractionArray, + aggCalls, + finalAggCallNeedRetractions, + orderKeyIndexes, needInputCount = false, isStateBackedDataViews = false, needDistinctInfo = false) } def transformToStreamAggregateInfoList( + inputRowType: RowType, aggregateCalls: Seq[AggregateCall], - inputRowType: RelDataType, - needRetraction: Array[Boolean], + aggCallNeedRetractions: Array[Boolean], needInputCount: Boolean, isStateBackendDataViews: Boolean, needDistinctInfo: Boolean = true): AggregateInfoList = { transformToAggregateInfoList( - aggregateCalls, inputRowType, - orderKeyIdx = null, - needRetraction ++ Array(needInputCount), // for additional count(*) + aggregateCalls, + aggCallNeedRetractions ++ Array(needInputCount), // for additional count(*) + orderKeyIndexes = null, needInputCount, isStateBackendDataViews, needDistinctInfo) @@ -264,10 +263,10 @@ object AggregateUtil extends Enumeration { /** * Transforms calcite aggregate calls to AggregateInfos. * + * @param inputRowType the input's output RowType * @param aggregateCalls the calcite aggregate calls - * @param inputRowType the input rel data type - * @param orderKeyIdx the index of order by field in the input, null if not over agg - * @param needRetraction whether the aggregate function need retract method + * @param aggCallNeedRetractions whether the aggregate function need retract method + * @param orderKeyIndexes the index of order by field in the input, null if not over agg * @param needInputCount whether need to calculate the input counts, which is used in * aggregation with retraction input.If needed, * insert a count(1) aggregate into the agg list. @@ -275,10 +274,10 @@ object AggregateUtil extends Enumeration { * @param needDistinctInfo whether need to extract distinct information */ private def transformToAggregateInfoList( + inputRowType: RowType, aggregateCalls: Seq[AggregateCall], - inputRowType: RelDataType, - orderKeyIdx: Array[Int], - needRetraction: Array[Boolean], + aggCallNeedRetractions: Array[Boolean], + orderKeyIndexes: Array[Int], needInputCount: Boolean, isStateBackedDataViews: Boolean, needDistinctInfo: Boolean): AggregateInfoList = { @@ -301,12 +300,12 @@ object AggregateUtil extends Enumeration { // Step-3: // create aggregate information - val factory = new AggFunctionFactory(inputRowType, orderKeyIdx, needRetraction) + val factory = new AggFunctionFactory(inputRowType, orderKeyIndexes, aggCallNeedRetractions) val aggInfos = newAggCalls .zipWithIndex .map { case (call, index) => val argIndexes = call.getAggregation match { - case _: SqlRankFunction => orderKeyIdx + case _: SqlRankFunction => orderKeyIndexes case _ => call.getArgList.map(_.intValue()).toArray } transformToAggregateInfo( @@ -316,14 +315,14 @@ object AggregateUtil extends Enumeration { argIndexes, factory.createAggFunction(call, index), isStateBackedDataViews, - needRetraction(index)) + aggCallNeedRetractions(index)) } AggregateInfoList(aggInfos.toArray, indexOfCountStar, countStarInserted, distinctInfos) } private def transformToAggregateInfo( - inputRowRelDataType: RelDataType, + inputRowType: RowType, call: AggregateCall, index: Int, argIndexes: Array[Int], @@ -334,7 +333,7 @@ object AggregateUtil extends Enumeration { case _: BridgingSqlAggFunction => createAggregateInfoFromBridgingFunction( - inputRowRelDataType, + inputRowType, call, index, argIndexes, @@ -344,7 +343,7 @@ object AggregateUtil extends Enumeration { case _: AggSqlFunction => createAggregateInfoFromLegacyFunction( - inputRowRelDataType, + inputRowType, call, index, argIndexes, @@ -363,7 +362,7 @@ object AggregateUtil extends Enumeration { } private def createAggregateInfoFromBridgingFunction( - inputRowRelDataType: RelDataType, + inputRowType: RowType, call: AggregateCall, index: Int, argIndexes: Array[Int], @@ -387,7 +386,7 @@ object AggregateUtil extends Enumeration { function.getTypeFactory, function, SqlTypeUtil.projectTypes( - inputRowRelDataType, + FlinkTypeFactory.INSTANCE.buildRelNodeRowType(inputRowType), argIndexes.map(Int.box).toList), 0, false)) @@ -490,7 +489,7 @@ object AggregateUtil extends Enumeration { } private def createAggregateInfoFromLegacyFunction( - inputRowRelDataType: RelDataType, + inputRowType: RowType, call: AggregateCall, index: Int, argIndexes: Array[Int], @@ -507,8 +506,7 @@ object AggregateUtil extends Enumeration { } val externalAccType = getAccumulatorTypeOfAggregateFunction(a, implicitAccType) val argTypes = call.getArgList - .map(idx => inputRowRelDataType.getFieldList.get(idx).getType) - .map(FlinkTypeFactory.toLogicalType) + .map(idx => inputRowType.getChildren.get(idx)) val externalArgTypes: Array[DataType] = getAggUserDefinedInputTypes( a, externalAccType, @@ -605,7 +603,7 @@ object AggregateUtil extends Enumeration { private def extractDistinctInformation( needDistinctInfo: Boolean, aggCalls: Seq[AggregateCall], - inputType: RelDataType, + inputType: RowType, hasStateBackedDataViews: Boolean, consumeRetraction: Boolean): (Array[DistinctInfo], Seq[AggregateCall]) = { @@ -621,8 +619,7 @@ object AggregateUtil extends Enumeration { if (call.isDistinct && !call.isApproximate && argIndexes.length > 0) { val argTypes: Array[LogicalType] = call .getArgList - .map(inputType.getFieldList.get(_).getType) - .map(FlinkTypeFactory.toLogicalType) + .map(inputType.getChildren.get(_)) .toArray val keyType = createDistinctKeyType(argTypes) @@ -790,9 +787,9 @@ object AggregateUtil extends Enumeration { */ def getNeedRetractions( groupCount: Int, + aggCalls: Seq[AggregateCall], needRetraction: Boolean, - monotonicity: RelModifiedMonotonicity, - aggCalls: Seq[AggregateCall]): Array[Boolean] = { + monotonicity: RelModifiedMonotonicity): Array[Boolean] = { val needRetractionArray = Array.fill(aggCalls.size)(needRetraction) if (monotonicity != null && needRetraction) { aggCalls.zipWithIndex.foreach { case (aggCall, idx) => diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala index e92b095..eaa38bf 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala @@ -952,7 +952,9 @@ class FlinkRelMdHandlerTestBase { val aggCalls = logicalAgg.getAggCallList val aggFunctionFactory = new AggFunctionFactory( - studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false)) + FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType), + Array.empty[Int], + Array.fill(aggCalls.size())(false)) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) } @@ -1018,11 +1020,11 @@ class FlinkRelMdHandlerTestBase { isMerge = false) val needRetractionArray = AggregateUtil.getNeedRetractions( - 1, needRetraction = false, null, aggCalls) + 1, aggCalls, needRetraction = false, null) val localAggInfoList = transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(studentStreamScan.getRowType), aggCalls, - studentStreamScan.getRowType, needRetractionArray, needInputCount = false, isStateBackendDataViews = false) @@ -1039,8 +1041,8 @@ class FlinkRelMdHandlerTestBase { val streamExchange1 = new StreamPhysicalExchange( cluster, streamLocalAgg.getTraitSet.replace(hash0), streamLocalAgg, hash0) val globalAggInfoList = transformToStreamAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(streamExchange1.getRowType), aggCalls, - streamExchange1.getRowType, needRetractionArray, needInputCount = false, isStateBackendDataViews = true) @@ -1103,7 +1105,9 @@ class FlinkRelMdHandlerTestBase { call => call.getAggregation != FlinkSqlOperatorTable.AUXILIARY_GROUP } val aggFunctionFactory = new AggFunctionFactory( - studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false)) + FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType), + Array.empty[Int], + Array.fill(aggCalls.size())(false)) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) } @@ -1245,7 +1249,8 @@ class FlinkRelMdHandlerTestBase { cluster, batchPhysicalTraits.replace(hash01), batchCalc, hash01) val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions( - flinkLogicalWindowAgg.getAggCallList, batchExchange1.getRowType) + FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType), + flinkLogicalWindowAgg.getAggCallList) val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates) val localWindowAggTypes = @@ -1390,7 +1395,8 @@ class FlinkRelMdHandlerTestBase { cluster, batchPhysicalTraits.replace(hash1), batchCalc, hash1) val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions( - flinkLogicalWindowAgg.getAggCallList, batchExchange1.getRowType) + FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType), + flinkLogicalWindowAgg.getAggCallList) val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates) val localWindowAggTypes = @@ -1538,7 +1544,8 @@ class FlinkRelMdHandlerTestBase { val aggCallsWithoutAuxGroup = flinkLogicalWindowAggWithAuxGroup.getAggCallList.drop(1) val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions( - aggCallsWithoutAuxGroup, batchExchange1.getRowType) + FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType), + aggCallsWithoutAuxGroup) val aggCallToAggFunction = aggCallsWithoutAuxGroup.zip(aggregates) val localWindowAggTypes = @@ -2438,7 +2445,9 @@ class FlinkRelMdHandlerTestBase { ).build().asInstanceOf[LogicalAggregate] val aggCalls = logicalAgg.getAggCallList val aggFunctionFactory = new AggFunctionFactory( - studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false)) + FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType), + Array.empty[Int], + Array.fill(aggCalls.size())(false)) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) }
