GitHub user fhueske commented on a diff in the pull request:

    https://github.com/apache/flink/pull/3423#discussion_r103801735
  
    --- Diff: 
flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
 ---
    @@ -363,199 +342,112 @@ object AggregateUtil {
               groupingOffsetMapping,
               aggOffsetMapping,
               groupingSetsMapping,
    -          intermediateRowArity,
               outputType.getFieldCount)
           }
         groupReduceFunction
       }
     
       /**
    -    * Create a [[org.apache.flink.api.common.functions.ReduceFunction]] 
for incremental window
    -    * aggregation.
    -    *
    +    * Create an [[AllWindowFunction]] for non-partitioned window 
aggregates.
         */
    -  private[flink] def createIncrementalAggregateReduceFunction(
    -      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
    -      inputType: RelDataType,
    -      outputType: RelDataType,
    -      groupings: Array[Int])
    -    : IncrementalAggregateReduceFunction = {
    -
    -    val aggregates = transformToAggregateFunctions(
    -      namedAggregates.map(_.getKey),inputType,groupings.length)._2
    -
    -    val groupingOffsetMapping =
    -      getGroupingOffsetAndAggOffsetMapping(
    -        namedAggregates,
    -        inputType,
    -        outputType,
    -        groupings)._1
    -
    -    val intermediateRowArity = groupings.length + 
aggregates.map(_.intermediateDataType.length).sum
    -    val reduceFunction = new IncrementalAggregateReduceFunction(
    -      aggregates,
    -      groupingOffsetMapping,
    -      intermediateRowArity)
    -    reduceFunction
    -  }
    -
    -  /**
    -    * Create an [[AllWindowFunction]] to compute non-partitioned group 
window aggregates.
    -    */
    -  private[flink] def createAllWindowAggregationFunction(
    +  private[flink] def createAggregationAllWindowFunction(
           window: LogicalWindow,
    -      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
    -      inputType: RelDataType,
    -      outputType: RelDataType,
    -      groupings: Array[Int],
    -      properties: Seq[NamedWindowProperty])
    -    : AllWindowFunction[Row, Row, DataStreamWindow] = {
    -
    -    val aggFunction =
    -      createAggregateGroupReduceFunction(
    -        namedAggregates,
    -        inputType,
    -        outputType,
    -        groupings,
    -        inGroupingSet = false)
    +      finalRowArity: Int,
    +      properties: Seq[NamedWindowProperty]): AllWindowFunction[Row, Row, 
DataStreamWindow] = {
     
         if (isTimeWindow(window)) {
           val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
    -      new AggregateAllTimeWindowFunction(aggFunction, startPos, endPos)
    -      .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]]
    +      new IncrementalAggregateAllTimeWindowFunction(
    +        startPos,
    +        endPos,
    +        finalRowArity)
    +        .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]]
         } else {
    -      new AggregateAllWindowFunction(aggFunction)
    +      new IncrementalAggregateAllWindowFunction(
    +        finalRowArity)
         }
       }
     
       /**
    -    * Create a [[WindowFunction]] to compute partitioned group window 
aggregates.
    -    *
    +    * Create a [[WindowFunction]] for group window aggregates.
         */
    -  private[flink] def createWindowAggregationFunction(
    +  private[flink] def createAggregationGroupWindowFunction(
           window: LogicalWindow,
    -      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
    -      inputType: RelDataType,
    -      outputType: RelDataType,
    -      groupings: Array[Int],
    -      properties: Seq[NamedWindowProperty])
    -    : WindowFunction[Row, Row, Tuple, DataStreamWindow] = {
    -
    -    val aggFunction =
    -      createAggregateGroupReduceFunction(
    -        namedAggregates,
    -        inputType,
    -        outputType,
    -        groupings,
    -        inGroupingSet = false)
    +      finalRowArity: Int,
    +      properties: Seq[NamedWindowProperty]): WindowFunction[Row, Row, 
Tuple, DataStreamWindow] = {
     
         if (isTimeWindow(window)) {
           val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
    -      new AggregateTimeWindowFunction(aggFunction, startPos, endPos)
    -      .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]]
    +      new IncrementalAggregateTimeWindowFunction(
    +        startPos,
    +        endPos,
    +        finalRowArity)
    +        .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]]
         } else {
    -      new AggregateWindowFunction(aggFunction)
    +      new IncrementalAggregateWindowFunction(
    +        finalRowArity)
         }
       }
     
    -  /**
    -    * Create an [[AllWindowFunction]] to finalize incrementally 
pre-computed non-partitioned
    -    * window aggregates.
    -    */
    -  private[flink] def createAllWindowIncrementalAggregationFunction(
    -      window: LogicalWindow,
    +  private[flink] def createDataStreamAggregateFunction(
           namedAggregates: Seq[CalcitePair[AggregateCall, String]],
           inputType: RelDataType,
           outputType: RelDataType,
    -      groupings: Array[Int],
    -      properties: Seq[NamedWindowProperty])
    -    : AllWindowFunction[Row, Row, DataStreamWindow] = {
    +      groupKeysIndex: Array[Int]): (ApiAggregateFunction[Row, Row, Row], 
RowTypeInfo) = {
     
    -    val aggregates = transformToAggregateFunctions(
    -      namedAggregates.map(_.getKey),inputType,groupings.length)._2
    +    val (aggFields, aggregates) =
    +      transformToAggregateFunctions(namedAggregates.map(_.getKey), 
inputType, groupKeysIndex.length)
     
    -    val (groupingOffsetMapping, aggOffsetMapping) =
    -      getGroupingOffsetAndAggOffsetMapping(
    -      namedAggregates,
    -      inputType,
    -      outputType,
    -      groupings)
    +    val groupKeysMapping = getGroupKeysMapping(inputType, outputType, 
groupKeysIndex)
     
    -    val finalRowArity = outputType.getFieldCount
    +    val aggregateMapping = getAggregateMapping(namedAggregates, outputType)
     
    -    if (isTimeWindow(window)) {
    -      val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
    -      new IncrementalAggregateAllTimeWindowFunction(
    -        aggregates,
    -        groupingOffsetMapping,
    -        aggOffsetMapping,
    -        finalRowArity,
    -        startPos,
    -        endPos)
    -      .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]]
    -    } else {
    -      new IncrementalAggregateAllWindowFunction(
    -        aggregates,
    -        groupingOffsetMapping,
    -        aggOffsetMapping,
    -        finalRowArity)
    +    if (groupKeysMapping.length != groupKeysIndex.length ||
    +      aggregateMapping.length != namedAggregates.length) {
    +      throw new TableException(
    +        "Could not find output field in input data type or aggregate 
functions.")
         }
    +
    +    val accumulatorRowType = createAccumulatorRowType(inputType, 
groupKeysIndex, aggregates)
    +    val aggFunction = new AggregateAggFunction(
    +      aggregates,
    +      aggFields,
    +      aggregateMapping,
    +      groupKeysIndex,
    +      groupKeysMapping,
    +      outputType.getFieldCount)
    +
    +    (aggFunction, accumulatorRowType)
       }
     
       /**
    -    * Create a [[WindowFunction]] to finalize incrementally pre-computed 
window aggregates.
    +    * Return true if all aggregates can be partially merged. False 
otherwise.
         */
    -  private[flink] def createWindowIncrementalAggregationFunction(
    -      window: LogicalWindow,
    -      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
    +  private[flink] def doAllSupportPartialMerge(
    +      aggregateCalls: Seq[AggregateCall],
           inputType: RelDataType,
    -      outputType: RelDataType,
    -      groupings: Array[Int],
    -      properties: Seq[NamedWindowProperty])
    -    : WindowFunction[Row, Row, Tuple, DataStreamWindow] = {
    +      groupKeysCount: Int): Boolean = {
     
    -    val aggregates = transformToAggregateFunctions(
    -      namedAggregates.map(_.getKey),inputType,groupings.length)._2
    -
    -    val (groupingOffsetMapping, aggOffsetMapping) =
    -      getGroupingOffsetAndAggOffsetMapping(
    -        namedAggregates,
    -        inputType,
    -        outputType,
    -        groupings)
    -
    -    val finalRowArity = outputType.getFieldCount
    +    val aggregateList = transformToAggregateFunctions(
    +      aggregateCalls,
    +      inputType,
    +      groupKeysCount)._2
     
    -    if (isTimeWindow(window)) {
    -      val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
    -      new IncrementalAggregateTimeWindowFunction(
    -        aggregates,
    -        groupingOffsetMapping,
    -        aggOffsetMapping,
    -        finalRowArity,
    -        startPos,
    -        endPos)
    -      .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]]
    -    } else {
    -      new IncrementalAggregateWindowFunction(
    -        aggregates,
    -        groupingOffsetMapping,
    -        aggOffsetMapping,
    -        finalRowArity)
    -    }
    +    doAllSupportPartialMerge(aggregateList)
       }
     
       /**
    -    * Return true if all aggregates can be partially computed. False 
otherwise.
    +    * Return true if all aggregates can be partially merged. False 
otherwise.
         */
    -  private[flink] def doAllSupportPartialAggregation(
    -    aggregateCalls: Seq[AggregateCall],
    -    inputType: RelDataType,
    -    groupKeysCount: Int): Boolean = {
    -    transformToAggregateFunctions(
    -      aggregateCalls,
    -      inputType,
    -      groupKeysCount)._2.forall(_.supportPartial)
    +  private[flink] def doAllSupportPartialMerge(
    +      aggregateList: Array[TableAggregateFunction[_ <: Any]]): Boolean = {
    +    var ret: Boolean = true
    --- End diff --
    
    This loop can be simplified to:
    ```
    aggregateList.forall(ifMethodExitInFunction("merge", _))
    ```


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to