wuchong commented on a change in pull request #8244: [FLINK-11945] 
[table-runtime-blink] Support over aggregation for blink streaming runtime
URL: https://github.com/apache/flink/pull/8244#discussion_r281019195
 
 

 ##########
 File path: 
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecOverAggregate.scala
 ##########
 @@ -132,8 +143,350 @@ class StreamExecOverAggregate(
     replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode])
   }
 
-  override protected def translateToPlanInternal(
+  override def translateToPlanInternal(
       tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = {
-    throw new TableException("Implements this")
+    val tableConfig = tableEnv.getConfig
+
+    if (logicWindow.groups.size > 1) {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "All aggregates must be computed on the same window."))
+    }
+
+    val overWindow: org.apache.calcite.rel.core.Window.Group = 
logicWindow.groups.get(0)
+
+    val orderKeys = overWindow.orderKeys.getFieldCollations
+
+    if (orderKeys.size() != 1) {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "The window can only be ordered by a single time column."))
+    }
+    val orderKey = orderKeys.get(0)
+
+    if (!orderKey.direction.equals(ASCENDING)) {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "The window can only be ordered in ASCENDING mode."))
+    }
+
+    val inputDS = getInputNodes.get(0).translateToPlan(tableEnv)
+      .asInstanceOf[StreamTransformation[BaseRow]]
+
+    val inputIsAccRetract = StreamExecRetractionRules.isAccRetract(input)
+
+    if (inputIsAccRetract) {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "Retraction on Over window aggregation is not supported yet. " +
+            "Note: Over window aggregation should not follow a non-windowed 
GroupBy aggregation."))
+    }
+
+    if (!logicWindow.groups.get(0).keys.isEmpty && 
tableConfig.getMinIdleStateRetentionTime < 0) {
+      LOG.warn(
+        "No state retention interval configured for a query which accumulates 
state. " +
+          "Please provide a query configuration with valid retention interval 
to prevent " +
+          "excessive state size. You may specify a retention time of 0 to not 
clean up the state.")
+    }
+
+    val timeType = 
outputRowType.getFieldList.get(orderKey.getFieldIndex).getType
+
+    // check time field
+    if (!FlinkTypeFactory.isRowtimeIndicatorType(timeType)
+      && !FlinkTypeFactory.isProctimeIndicatorType(timeType)) {
+      throw new TableException(
+        "OVER windows' ordering in stream mode must be defined on a time 
attribute.")
+    }
+
+    // identify window rowtime attribute
+    val rowTimeIdx: Option[Int] = if 
(FlinkTypeFactory.isRowtimeIndicatorType(timeType)) {
+      Some(orderKey.getFieldIndex)
+    } else if (FlinkTypeFactory.isProctimeIndicatorType(timeType)) {
+      None
+    } else {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "OVER windows can only be applied on time attributes."))
+    }
+
+    val config = tableEnv.getConfig
+    val codeGenCtx = CodeGeneratorContext(config)
+    val aggregateCalls = 
logicWindow.groups.get(0).getAggregateCalls(logicWindow).asScala
+    val isRowsClause = overWindow.isRows
+    val constants = logicWindow.constants.asScala
+    val constantTypes = constants.map(c => 
FlinkTypeFactory.toInternalType(c.getType))
+
+    val fieldNames = inputRowType.getFieldNames.asScala
+    val fieldTypes = inputRowType.getFieldList.asScala
+      .map(c => FlinkTypeFactory.toInternalType(c.getType))
+
+    val inRowType = FlinkTypeFactory.toInternalRowType(inputRel.getRowType)
+    val outRowType = FlinkTypeFactory.toInternalRowType(outputRowType)
+
+    val aggInputType = tableEnv.getTypeFactory.buildRelDataType(
+      fieldNames ++ constants.indices.map(i => "TMP" + i),
+      fieldTypes ++ constantTypes)
+
+    val overProcessFunction = if (overWindow.lowerBound.isPreceding
+      && overWindow.lowerBound.isUnbounded
+      && overWindow.upperBound.isCurrentRow) {
+
+      // unbounded OVER window
+      createUnboundedOverProcessFunction(
+        codeGenCtx,
+        aggregateCalls,
+        constants,
+        aggInputType,
+        rowTimeIdx,
+        isRowsClause,
+        tableConfig,
+        tableEnv.getRelBuilder,
+        config.getNullCheck)
+
+    } else if (overWindow.lowerBound.isPreceding
+      && !overWindow.lowerBound.isUnbounded
+      && overWindow.upperBound.isCurrentRow) {
+
+      val boundValue = OverAggregateUtil.getBoundary(logicWindow, 
overWindow.lowerBound)
+
+      if (boundValue.isInstanceOf[BigDecimal]) {
+        throw new TableException(
+          TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+            "the specific value is decimal which haven not supported yet."))
+      }
+      // bounded OVER window
+      val precedingOffset = -1 * boundValue.asInstanceOf[Long] + (if 
(isRowsClause) 1 else 0)
+
+      createBoundedOverProcessFunction(
+        codeGenCtx,
+        aggregateCalls,
+        constants,
+        aggInputType,
+        rowTimeIdx,
+        isRowsClause,
+        precedingOffset,
+        tableConfig,
+        tableEnv.getRelBuilder,
+        config.getNullCheck)
+
+    } else {
+      throw new TableException(
+        TableErrors.INST.sqlOverAggInvalidUseOfOverWindow(
+          "OVER RANGE FOLLOWING windows are not supported yet."))
+    }
+
+    val partitionKeys: Array[Int] = overWindow.keys.toArray
+    val inputTypeInfo = inRowType.toTypeInfo
+
+    val selector = KeySelectorUtil.getBaseRowSelector(partitionKeys, 
inputTypeInfo)
+
+    val returnTypeInfo = outRowType.toTypeInfo
+      .asInstanceOf[BaseRowTypeInfo]
+    // partitioned aggregation
+
+    val operator = new KeyedProcessOperator(overProcessFunction)
+
+    val ret = new OneInputTransformation(
+      inputDS,
+      getOperatorName,
+      operator,
+      returnTypeInfo,
+      inputDS.getParallelism)
+
+    if (partitionKeys.isEmpty) {
+      ret.setParallelism(1)
+      ret.setMaxParallelism(1)
+    }
+
+    // set KeyType and Selector for state
+    ret.setStateKeySelector(selector)
+    ret.setStateKeyType(selector.getProducedType)
+    ret
+  }
+
+  /**
+    * Create an ProcessFunction for unbounded OVER window to evaluate final 
aggregate value.
+    *
+    * @param ctx            code generator context
+    * @param aggregateCalls physical calls to aggregate functions and their 
output field names
+    * @param constants      the constants in aggregates parameters, such as 
sum(1)
+    * @param aggInputType   physical type of the input row which consist of 
input and constants.
+    * @param rowTimeIdx     the index of the rowtime field or None in case of 
processing time.
+    * @param isRowsClause   it is a tag that indicates whether the OVER clause 
is ROWS clause
+    */
+  private def createUnboundedOverProcessFunction(
+      ctx: CodeGeneratorContext,
+      aggregateCalls: Seq[AggregateCall],
+      constants: Seq[RexLiteral],
+      aggInputType: RelDataType,
+      rowTimeIdx: Option[Int],
+      isRowsClause: Boolean,
+      tableConfig: TableConfig,
+      relBuilder: RelBuilder,
+      nullCheck: Boolean): KeyedProcessFunction[BaseRow, BaseRow, BaseRow] = {
+
+    val needRetraction = false
+    val aggInfoList = transformToStreamAggregateInfoList(
+      aggregateCalls,
+      // use aggInputType which considers constants as input instead of 
inputSchema.relDataType
+      aggInputType,
+      Array.fill(aggregateCalls.size)(needRetraction),
+      needInputCount = needRetraction,
+      isStateBackendDataViews = true)
+
+    val fieldTypes = inputRowType.getFieldList.asScala.
+      map(c => FlinkTypeFactory.toInternalType(c.getType)).toArray
+
+    val generator = new AggsHandlerCodeGenerator(
+      ctx,
+      relBuilder,
+      fieldTypes,
+      needRetraction,
+      copyInputField = false)
+
+    val genAggsHandler = generator
+      // over agg code gen must pass the constants
+      .withConstants(constants)
+      .generateAggsHandler("UnboundedOverAggregateHelper", aggInfoList)
+
+    val flattenAccTypes = aggInfoList.getAccTypes.map(
+      TypeConverters.createInternalTypeFromTypeInfo)
+
+    if (rowTimeIdx.isDefined) {
+      if (isRowsClause) {
+        // ROWS unbounded over process function
+        new RowTimeUnboundedRowsOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          rowTimeIdx.get,
+          tableConfig)
+      } else {
+        // RANGE unbounded over process function
+        new RowTimeUnboundedRangeOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          rowTimeIdx.get,
+          tableConfig)
+      }
+    } else {
+      new ProcTimeUnboundedOver(
+        genAggsHandler,
+        flattenAccTypes,
+        tableConfig)
+    }
+  }
+
+  /**
+    * Create an ProcessFunction for ROWS clause bounded OVER window to 
evaluate final
+    * aggregate value.
+    *
+    * @param ctx            code generator context
+    * @param aggregateCalls physical calls to aggregate functions and their 
output field names
+    * @param constants      the constants in aggregates parameters, such as 
sum(1)
+    * @param aggInputType   physical type of the input row which consist of 
input and constants.
+    * @param rowTimeIdx     the index of the rowtime field or None in case of 
processing time.
+    * @param isRowsClause   it is a tag that indicates whether the OVER clause 
is ROWS clause
+    */
+  private def createBoundedOverProcessFunction(
+      ctx: CodeGeneratorContext,
+      aggregateCalls: Seq[AggregateCall],
+      constants: Seq[RexLiteral],
+      aggInputType: RelDataType,
+      rowTimeIdx: Option[Int],
+      isRowsClause: Boolean,
+      precedingOffset: Long,
+      tableConfig: TableConfig,
+      relBuilder: RelBuilder,
+      nullCheck: Boolean): KeyedProcessFunction[BaseRow, BaseRow, BaseRow] = {
+
+    val needRetraction = true
+    val aggInfoList = transformToStreamAggregateInfoList(
+      aggregateCalls,
+      // use aggInputType which considers constants as input instead of 
inputSchema.relDataType
+      aggInputType,
+      Array.fill(aggregateCalls.size)(needRetraction),
+      needInputCount = needRetraction,
+      isStateBackendDataViews = true)
+
+    val fieldTypes = inputRowType.getFieldList.asScala.
+      map(c => FlinkTypeFactory.toInternalType(c.getType)).toArray
+
+    val generator = new AggsHandlerCodeGenerator(
+      ctx,
+      relBuilder,
+      fieldTypes,
+      needRetraction,
+      copyInputField = false)
+
+    val genAggsHandler = generator
+      // over agg code gen must pass the constants
+      .withConstants(constants)
+      .generateAggsHandler("BoundedOverAggregateHelper", aggInfoList)
+
+    val flattenAccTypes = aggInfoList.getAccTypes.map(
+      TypeConverters.createInternalTypeFromTypeInfo)
+
+    if (rowTimeIdx.isDefined) {
+      if (isRowsClause) {
+        new RowTimeBoundedRowsOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          precedingOffset,
+          rowTimeIdx.get,
+          tableConfig)
+      } else {
+        new RowTimeBoundedRangeOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          precedingOffset,
+          rowTimeIdx.get,
+          tableConfig)
+      }
+    } else {
+      if (isRowsClause) {
+        new ProcTimeBoundedRowsOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          precedingOffset,
+          tableConfig)
+      } else {
+        new ProcTimeBoundedRangeOver(
+          genAggsHandler,
+          flattenAccTypes,
+          fieldTypes,
+          precedingOffset,
+          tableConfig)
+      }
+    }
+  }
+
+  private def getOperatorName = {
 
 Review comment:
   The logic for building the operator name duplicates that of `explainTerms`. We are 
planning to reuse `explainTerms` for the operator name. You can simply use 
`"OverAggregate"` as the operator name, so that we can avoid introducing this fat 
method and the changes in `OverAggregateUtil`.
   
   In order to avoid duplicate 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to