wuchong commented on a change in pull request #8244: [FLINK-11945]
[table-runtime-blink] Support over aggregation for blink streaming runtime
URL: https://github.com/apache/flink/pull/8244#discussion_r281019195
##########
File path:
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecOverAggregate.scala
##########
@@ -132,8 +143,350 @@ class StreamExecOverAggregate(
replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode])
}
- override protected def translateToPlanInternal(
+ override def translateToPlanInternal(
tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = {
- throw new TableException("Implements this")
+ val tableConfig = tableEnv.getConfig
+
+ if (logicWindow.groups.size > 1) {
+ throw new TableException(
+ TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+ "All aggregates must be computed on the same window."))
+ }
+
+ val overWindow: org.apache.calcite.rel.core.Window.Group =
logicWindow.groups.get(0)
+
+ val orderKeys = overWindow.orderKeys.getFieldCollations
+
+ if (orderKeys.size() != 1) {
+ throw new TableException(
+ TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+ "The window can only be ordered by a single time column."))
+ }
+ val orderKey = orderKeys.get(0)
+
+ if (!orderKey.direction.equals(ASCENDING)) {
+ throw new TableException(
+ TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+ "The window can only be ordered in ASCENDING mode."))
+ }
+
+ val inputDS = getInputNodes.get(0).translateToPlan(tableEnv)
+ .asInstanceOf[StreamTransformation[BaseRow]]
+
+ val inputIsAccRetract = StreamExecRetractionRules.isAccRetract(input)
+
+ if (inputIsAccRetract) {
+ throw new TableException(
+ TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+ "Retraction on Over window aggregation is not supported yet. " +
+ "Note: Over window aggregation should not follow a non-windowed
GroupBy aggregation."))
+ }
+
+ if (!logicWindow.groups.get(0).keys.isEmpty &&
tableConfig.getMinIdleStateRetentionTime < 0) {
+ LOG.warn(
+ "No state retention interval configured for a query which accumulates
state. " +
+ "Please provide a query configuration with valid retention interval
to prevent " +
+ "excessive state size. You may specify a retention time of 0 to not
clean up the state.")
+ }
+
+ val timeType =
outputRowType.getFieldList.get(orderKey.getFieldIndex).getType
+
+ // check time field
+ if (!FlinkTypeFactory.isRowtimeIndicatorType(timeType)
+ && !FlinkTypeFactory.isProctimeIndicatorType(timeType)) {
+ throw new TableException(
+ "OVER windows' ordering in stream mode must be defined on a time
attribute.")
+ }
+
+ // identify window rowtime attribute
+ val rowTimeIdx: Option[Int] = if
(FlinkTypeFactory.isRowtimeIndicatorType(timeType)) {
+ Some(orderKey.getFieldIndex)
+ } else if (FlinkTypeFactory.isProctimeIndicatorType(timeType)) {
+ None
+ } else {
+ throw new TableException(
+ TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+ "OVER windows can only be applied on time attributes."))
+ }
+
+ val config = tableEnv.getConfig
+ val codeGenCtx = CodeGeneratorContext(config)
+ val aggregateCalls =
logicWindow.groups.get(0).getAggregateCalls(logicWindow).asScala
+ val isRowsClause = overWindow.isRows
+ val constants = logicWindow.constants.asScala
+ val constantTypes = constants.map(c =>
FlinkTypeFactory.toInternalType(c.getType))
+
+ val fieldNames = inputRowType.getFieldNames.asScala
+ val fieldTypes = inputRowType.getFieldList.asScala
+ .map(c => FlinkTypeFactory.toInternalType(c.getType))
+
+ val inRowType = FlinkTypeFactory.toInternalRowType(inputRel.getRowType)
+ val outRowType = FlinkTypeFactory.toInternalRowType(outputRowType)
+
+ val aggInputType = tableEnv.getTypeFactory.buildRelDataType(
+ fieldNames ++ constants.indices.map(i => "TMP" + i),
+ fieldTypes ++ constantTypes)
+
+ val overProcessFunction = if (overWindow.lowerBound.isPreceding
+ && overWindow.lowerBound.isUnbounded
+ && overWindow.upperBound.isCurrentRow) {
+
+ // unbounded OVER window
+ createUnboundedOverProcessFunction(
+ codeGenCtx,
+ aggregateCalls,
+ constants,
+ aggInputType,
+ rowTimeIdx,
+ isRowsClause,
+ tableConfig,
+ tableEnv.getRelBuilder,
+ config.getNullCheck)
+
+ } else if (overWindow.lowerBound.isPreceding
+ && !overWindow.lowerBound.isUnbounded
+ && overWindow.upperBound.isCurrentRow) {
+
+ val boundValue = OverAggregateUtil.getBoundary(logicWindow,
overWindow.lowerBound)
+
+ if (boundValue.isInstanceOf[BigDecimal]) {
+ throw new TableException(
+ TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+ "the specific value is decimal which haven not supported yet."))
+ }
+ // bounded OVER window
+ val precedingOffset = -1 * boundValue.asInstanceOf[Long] + (if
(isRowsClause) 1 else 0)
+
+ createBoundedOverProcessFunction(
+ codeGenCtx,
+ aggregateCalls,
+ constants,
+ aggInputType,
+ rowTimeIdx,
+ isRowsClause,
+ precedingOffset,
+ tableConfig,
+ tableEnv.getRelBuilder,
+ config.getNullCheck)
+
+ } else {
+ throw new TableException(
+ TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+ "OVER RANGE FOLLOWING windows are not supported yet."))
+ }
+
+ val partitionKeys: Array[Int] = overWindow.keys.toArray
+ val inputTypeInfo = inRowType.toTypeInfo
+
+ val selector = KeySelectorUtil.getBaseRowSelector(partitionKeys,
inputTypeInfo)
+
+ val returnTypeInfo = outRowType.toTypeInfo
+ .asInstanceOf[BaseRowTypeInfo]
+ // partitioned aggregation
+
+ val operator = new KeyedProcessOperator(overProcessFunction)
+
+ val ret = new OneInputTransformation(
+ inputDS,
+ getOperatorName,
+ operator,
+ returnTypeInfo,
+ inputDS.getParallelism)
+
+ if (partitionKeys.isEmpty) {
+ ret.setParallelism(1)
+ ret.setMaxParallelism(1)
+ }
+
+ // set KeyType and Selector for state
+ ret.setStateKeySelector(selector)
+ ret.setStateKeyType(selector.getProducedType)
+ ret
+ }
+
+ /**
+ * Create an ProcessFunction for unbounded OVER window to evaluate final
aggregate value.
+ *
+ * @param ctx code generator context
+ * @param aggregateCalls physical calls to aggregate functions and their
output field names
+ * @param constants the constants in aggregates parameters, such as
sum(1)
+ * @param aggInputType physical type of the input row which consist of
input and constants.
+ * @param rowTimeIdx the index of the rowtime field or None in case of
processing time.
+ * @param isRowsClause it is a tag that indicates whether the OVER clause
is ROWS clause
+ */
+ private def createUnboundedOverProcessFunction(
+ ctx: CodeGeneratorContext,
+ aggregateCalls: Seq[AggregateCall],
+ constants: Seq[RexLiteral],
+ aggInputType: RelDataType,
+ rowTimeIdx: Option[Int],
+ isRowsClause: Boolean,
+ tableConfig: TableConfig,
+ relBuilder: RelBuilder,
+ nullCheck: Boolean): KeyedProcessFunction[BaseRow, BaseRow, BaseRow] = {
+
+ val needRetraction = false
+ val aggInfoList = transformToStreamAggregateInfoList(
+ aggregateCalls,
+ // use aggInputType which considers constants as input instead of
inputSchema.relDataType
+ aggInputType,
+ Array.fill(aggregateCalls.size)(needRetraction),
+ needInputCount = needRetraction,
+ isStateBackendDataViews = true)
+
+ val fieldTypes = inputRowType.getFieldList.asScala.
+ map(c => FlinkTypeFactory.toInternalType(c.getType)).toArray
+
+ val generator = new AggsHandlerCodeGenerator(
+ ctx,
+ relBuilder,
+ fieldTypes,
+ needRetraction,
+ copyInputField = false)
+
+ val genAggsHandler = generator
+ // over agg code gen must pass the constants
+ .withConstants(constants)
+ .generateAggsHandler("UnboundedOverAggregateHelper", aggInfoList)
+
+ val flattenAccTypes = aggInfoList.getAccTypes.map(
+ TypeConverters.createInternalTypeFromTypeInfo)
+
+ if (rowTimeIdx.isDefined) {
+ if (isRowsClause) {
+ // ROWS unbounded over process function
+ new RowTimeUnboundedRowsOver(
+ genAggsHandler,
+ flattenAccTypes,
+ fieldTypes,
+ rowTimeIdx.get,
+ tableConfig)
+ } else {
+ // RANGE unbounded over process function
+ new RowTimeUnboundedRangeOver(
+ genAggsHandler,
+ flattenAccTypes,
+ fieldTypes,
+ rowTimeIdx.get,
+ tableConfig)
+ }
+ } else {
+ new ProcTimeUnboundedOver(
+ genAggsHandler,
+ flattenAccTypes,
+ tableConfig)
+ }
+ }
+
+ /**
+ * Create an ProcessFunction for ROWS clause bounded OVER window to
evaluate final
+ * aggregate value.
+ *
+ * @param ctx code generator context
+ * @param aggregateCalls physical calls to aggregate functions and their
output field names
+ * @param constants the constants in aggregates parameters, such as
sum(1)
+ * @param aggInputType physical type of the input row which consist of
input and constants.
+ * @param rowTimeIdx the index of the rowtime field or None in case of
processing time.
+ * @param isRowsClause it is a tag that indicates whether the OVER clause
is ROWS clause
+ */
+ private def createBoundedOverProcessFunction(
+ ctx: CodeGeneratorContext,
+ aggregateCalls: Seq[AggregateCall],
+ constants: Seq[RexLiteral],
+ aggInputType: RelDataType,
+ rowTimeIdx: Option[Int],
+ isRowsClause: Boolean,
+ precedingOffset: Long,
+ tableConfig: TableConfig,
+ relBuilder: RelBuilder,
+ nullCheck: Boolean): KeyedProcessFunction[BaseRow, BaseRow, BaseRow] = {
+
+ val needRetraction = true
+ val aggInfoList = transformToStreamAggregateInfoList(
+ aggregateCalls,
+ // use aggInputType which considers constants as input instead of
inputSchema.relDataType
+ aggInputType,
+ Array.fill(aggregateCalls.size)(needRetraction),
+ needInputCount = needRetraction,
+ isStateBackendDataViews = true)
+
+ val fieldTypes = inputRowType.getFieldList.asScala.
+ map(c => FlinkTypeFactory.toInternalType(c.getType)).toArray
+
+ val generator = new AggsHandlerCodeGenerator(
+ ctx,
+ relBuilder,
+ fieldTypes,
+ needRetraction,
+ copyInputField = false)
+
+ val genAggsHandler = generator
+ // over agg code gen must pass the constants
+ .withConstants(constants)
+ .generateAggsHandler("BoundedOverAggregateHelper", aggInfoList)
+
+ val flattenAccTypes = aggInfoList.getAccTypes.map(
+ TypeConverters.createInternalTypeFromTypeInfo)
+
+ if (rowTimeIdx.isDefined) {
+ if (isRowsClause) {
+ new RowTimeBoundedRowsOver(
+ genAggsHandler,
+ flattenAccTypes,
+ fieldTypes,
+ precedingOffset,
+ rowTimeIdx.get,
+ tableConfig)
+ } else {
+ new RowTimeBoundedRangeOver(
+ genAggsHandler,
+ flattenAccTypes,
+ fieldTypes,
+ precedingOffset,
+ rowTimeIdx.get,
+ tableConfig)
+ }
+ } else {
+ if (isRowsClause) {
+ new ProcTimeBoundedRowsOver(
+ genAggsHandler,
+ flattenAccTypes,
+ fieldTypes,
+ precedingOffset,
+ tableConfig)
+ } else {
+ new ProcTimeBoundedRangeOver(
+ genAggsHandler,
+ flattenAccTypes,
+ fieldTypes,
+ precedingOffset,
+ tableConfig)
+ }
+ }
+ }
+
+ private def getOperatorName = {
Review comment:
The operator-name logic duplicates what `explainTerms` already does. We are planning
to reuse `explainTerms` for the operator name, so you can simply use `"OverAggregate"`
as the operator name here. That way we can avoid introducing this fat method and the
changes in `OverAggregateUtil`.
In order to avoid duplicate
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services