Github user fhueske commented on a diff in the pull request:
https://github.com/apache/flink/pull/3735#discussion_r111995303
--- Diff:
flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala
---
@@ -19,88 +19,71 @@
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler,
GeneratedAggregationsFunction}
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
+import org.slf4j.LoggerFactory
/**
* [[RichGroupReduceFunction]] to compute the final result of a
pre-aggregated aggregation
* for batch (DataSet) queries.
*
- * @param aggregates The aggregate functions.
- * @param aggOutFields The positions of the aggregation results in the
output
+ * @param genAggregations Code-generated [[GeneratedAggregations]]
* @param gkeyOutFields The positions of the grouping keys in the output
* @param groupingSetsMapping The mapping of grouping set keys between
input and output positions.
- * @param finalRowArity The arity of the final resulting row
*/
class DataSetFinalAggFunction(
- private val aggregates: Array[AggregateFunction[_ <: Any]],
- private val aggOutFields: Array[Int],
+ private val genAggregations: GeneratedAggregationsFunction,
private val gkeyOutFields: Array[Int],
- private val groupingSetsMapping: Array[(Int, Int)],
- private val finalRowArity: Int)
- extends RichGroupReduceFunction[Row, Row] {
+ private val groupingSetsMapping: Array[(Int, Int)])
+ extends RichGroupReduceFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
- Preconditions.checkNotNull(aggregates)
- Preconditions.checkNotNull(aggOutFields)
Preconditions.checkNotNull(gkeyOutFields)
Preconditions.checkNotNull(groupingSetsMapping)
private var output: Row = _
+ private var accumulators: Row = _
+
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ private var function: GeneratedAggregations = _
private val intermediateGKeys: Option[Array[Int]] = if
(!groupingSetsMapping.isEmpty) {
Some(gkeyOutFields)
} else {
None
}
- private val numAggs = aggregates.length
- private val numGKeys = gkeyOutFields.length
-
- private val accumulators: Array[JArrayList[Accumulator]] =
- Array.fill(numAggs)(new JArrayList[Accumulator](2))
-
override def open(config: Configuration) {
- output = new Row(finalRowArity)
-
- // init lists with two empty accumulators
- for (i <- aggregates.indices) {
- val accumulator = aggregates(i).createAccumulator()
- accumulators(i).add(accumulator)
- accumulators(i).add(accumulator)
- }
+ LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+ s"Code:\n$genAggregations.code")
+ val clazz = compile(
+ getClass.getClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+ accumulators = function.createAccumulators()
}
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit =
{
val iterator = records.iterator()
// reset first accumulator
- var i = 0
- while (i < aggregates.length) {
- aggregates(i).resetAccumulator(accumulators(i).get(0))
- i += 1
- }
+ function.resetAccumulator(accumulators)
+ var i = 0
while (iterator.hasNext) {
val record = iterator.next()
--- End diff --
We can make `record` a `var` and move its declaration outside of the loop.
Then we can get rid of the `if (!iterator.hasNext)` check in the body of
the `while` loop and set the `output` fields after the loop has terminated.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working,
please contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---