Github user fhueske commented on a diff in the pull request:

    https://github.com/apache/flink/pull/3735#discussion_r112003763
  
    --- Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala ---
    @@ -25,58 +25,56 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
     import org.apache.flink.api.java.typeutils.ResultTypeQueryable
     import org.apache.flink.configuration.Configuration
     import org.apache.flink.streaming.api.windowing.windows.TimeWindow
    -import org.apache.flink.table.functions.AggregateFunction
    +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
     import org.apache.flink.types.Row
    -import org.apache.flink.util.Preconditions
    -
    +import org.slf4j.LoggerFactory
     
     /**
       * This map function only works for windows on batch tables.
       * It appends an (aligned) rowtime field to the end of the output row.
    +  *
    +  * @param genAggregations      Code-generated [[GeneratedAggregations]]
    +  * @param timeFieldPos         Time field position in input row
    +  * @param tumbleTimeWindowSize The size of tumble time window
       */
     class DataSetWindowAggMapFunction(
    -    private val aggregates: Array[AggregateFunction[_]],
    -    private val aggFields: Array[Array[Int]],
    -    private val groupingKeys: Array[Int],
    -    private val timeFieldPos: Int, // time field position in input row
    +    private val genAggregations: GeneratedAggregationsFunction,
    +    private val timeFieldPos: Int,
         private val tumbleTimeWindowSize: Option[Long],
         @transient private val returnType: TypeInformation[Row])
    -  extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] {
    -
    -  Preconditions.checkNotNull(aggregates)
    -  Preconditions.checkNotNull(aggFields)
    -  Preconditions.checkArgument(aggregates.length == aggFields.length)
    +  extends RichMapFunction[Row, Row]
    +    with ResultTypeQueryable[Row]
    +    with Compiler[GeneratedAggregations] {
     
       private var output: Row = _
    -  // add one more arity to store rowtime
    -  private val partialRowLength = groupingKeys.length + aggregates.length + 1
    -  // rowtime index in the buffer output row
    -  private val rowtimeIndex: Int = partialRowLength - 1
    +
    +  val LOG = LoggerFactory.getLogger(this.getClass)
    +  private var function: GeneratedAggregations = _
     
       override def open(config: Configuration) {
    -    output = new Row(partialRowLength)
    +    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
    +                s"Code:\n$genAggregations.code")
    +    val clazz = compile(
    +      getRuntimeContext.getUserCodeClassLoader,
    +      genAggregations.name,
    +      genAggregations.code)
    +    LOG.debug("Instantiating AggregateHelper.")
    +    function = clazz.newInstance()
    +
    +    output = function.createOutputRow()
       }
     
       override def map(input: Row): Row = {
     
    -    var i = 0
    -    while (i < aggregates.length) {
    -      val agg = aggregates(i)
    -      val fieldValue = input.getField(aggFields(i)(0))
    -      val accumulator = agg.createAccumulator()
    -      agg.accumulate(accumulator, fieldValue)
    -      output.setField(groupingKeys.length + i, accumulator)
    -      i += 1
    -    }
    +    function.createAccumulatorsAndSetToOutput(output)
    --- End diff --
    
    Could we create an accumulator once in `open()` with `function.createAccumulator()`,
    reset it here, and copy it to `output` with `function.setAggregationResults()`?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

Reply via email to