[FLINK-5955] [table] Fix aggregations with ObjectReuse enabled by pairwise merging of accumulators.
This closes #3465. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/14fab4c4 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/14fab4c4 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/14fab4c4 Branch: refs/heads/master Commit: 14fab4c412048f769209855d876221817e73ba25 Parents: 2d1721b Author: shaoxuan-wang <[email protected]> Authored: Fri Mar 3 13:50:29 2017 +0800 Committer: Fabian Hueske <[email protected]> Committed: Fri Mar 3 14:27:08 2017 +0100 ---------------------------------------------------------------------- .../table/functions/AggregateFunction.scala | 7 ++- .../AggregateReduceCombineFunction.scala | 51 ++++++++-------- .../AggregateReduceGroupFunction.scala | 52 ++++++++-------- ...ionWindowAggregateCombineGroupFunction.scala | 58 +++++++++--------- ...sionWindowAggregateReduceGroupFunction.scala | 62 +++++++++++--------- ...umbleCountWindowAggReduceGroupFunction.scala | 47 ++++++++------- ...mbleTimeWindowAggReduceCombineFunction.scala | 40 ++++++------- ...TumbleTimeWindowAggReduceGroupFunction.scala | 49 +++++++++------- 8 files changed, 191 insertions(+), 175 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala index 178b439..e5666ce 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala @@ -58,8 +58,11 @@ abstract class 
AggregateFunction[T] extends UserDefinedFunction { /** * Merge a list of accumulator instances into one accumulator instance. * - * @param accumulators the [[java.util.List]] of accumulators - * that will be merged + * IMPORTANT: You may only return a new accumulator instance or the first accumulator of the + * input list. If you return another instance, the result of the aggregation function might be + * incorrect. + * + * @param accumulators the [[java.util.List]] of accumulators that will be merged + * @return the resulting accumulator */ def merge(accumulators: JList[Accumulator]): Accumulator http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala index 06ac8fb..6b95cb8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala @@ -19,9 +19,9 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable -import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.CombineFunction +import org.apache.flink.configuration.Configuration import org.apache.flink.table.functions.{Accumulator, AggregateFunction} import org.apache.flink.types.Row @@ -53,6 +53,13 @@ class AggregateReduceCombineFunction( groupingSetsMapping, finalRowArity) with CombineFunction[Row, Row] { + var preAggOutput: Row = _ + + override def open(config: Configuration): Unit = { + super.open(config) + 
preAggOutput = new Row(aggregates.length + groupKeysMapping.length) + } + /** * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer, * @@ -62,45 +69,41 @@ class AggregateReduceCombineFunction( */ override def combine(records: Iterable[Row]): Row = { - // merge intermediate aggregate value to buffer. var last: Row = null - accumulatorList.foreach(_.clear()) - val iterator = records.iterator() - var count: Int = 0 + // reset first accumulator in merge list + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } + while (iterator.hasNext) { val record = iterator.next() - count += 1 - // per each aggregator, collect its accumulators to a list + for (i <- aggregates.indices) { - accumulatorList(i).add(record.getField(groupKeysMapping.length + i) - .asInstanceOf[Accumulator]) - } - // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one - // accumulator - if (count > maxMergeLen) { - count = 0 - for (i <- aggregates.indices) { - val agg = aggregates(i) - val accumulator = agg.merge(accumulatorList(i)) - accumulatorList(i).clear() - accumulatorList(i).add(accumulator) - } + // insert received accumulator into acc list + val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator] + accumulatorList(i).set(1, newAcc) + // merge acc list + val retAcc = aggregates(i).merge(accumulatorList(i)) + // insert result into acc list + accumulatorList(i).set(0, retAcc) } + last = record } + // set the partial merged result to the aggregateBuffer for (i <- aggregates.indices) { - val agg = aggregates(i) - aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i))) + preAggOutput.setField(groupKeysMapping.length + i, accumulatorList(i).get(0)) } // set group keys to aggregateBuffer. 
for (i <- groupKeysMapping.indices) { - aggregateBuffer.setField(i, last.getField(i)) + preAggOutput.setField(i, last.getField(i)) } - aggregateBuffer + preAggOutput } } http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala index 23b5236..2f75cd7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala @@ -48,22 +48,27 @@ class AggregateReduceGroupFunction( private val finalRowArity: Int) extends RichGroupReduceFunction[Row, Row] { - protected var aggregateBuffer: Row = _ private var output: Row = _ private var intermediateGroupKeys: Option[Array[Int]] = None - protected val maxMergeLen = 16 - val accumulatorList = Array.fill(aggregates.length) { - new JArrayList[Accumulator]() + + val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { + new JArrayList[Accumulator](2) } override def open(config: Configuration) { Preconditions.checkNotNull(aggregates) Preconditions.checkNotNull(groupKeysMapping) - aggregateBuffer = new Row(aggregates.length + groupKeysMapping.length) output = new Row(finalRowArity) if (!groupingSetsMapping.isEmpty) { intermediateGroupKeys = Some(groupKeysMapping.map(_._1)) } + + // init lists with two empty accumulators + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).add(accumulator) + accumulatorList(i).add(accumulator) + 
} } /** @@ -77,32 +82,28 @@ class AggregateReduceGroupFunction( */ override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { - // merge intermediate aggregate value to buffer. var last: Row = null - accumulatorList.foreach(_.clear()) - val iterator = records.iterator() - var count: Int = 0 + // reset first accumulator in merge list + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } + while (iterator.hasNext) { val record = iterator.next() - count += 1 - // per each aggregator, collect its accumulators to a list + for (i <- aggregates.indices) { - accumulatorList(i).add(record.getField(groupKeysMapping.length + i) - .asInstanceOf[Accumulator]) - } - // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one - // accumulator - if (count > maxMergeLen) { - count = 0 - for (i <- aggregates.indices) { - val agg = aggregates(i) - val accumulator = agg.merge(accumulatorList(i)) - accumulatorList(i).clear() - accumulatorList(i).add(accumulator) - } + // insert received accumulator into acc list + val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator] + accumulatorList(i).set(1, newAcc) + // merge acc list + val retAcc = aggregates(i).merge(accumulatorList(i)) + // insert result into acc list + accumulatorList(i).set(0, retAcc) } + last = record } @@ -116,8 +117,7 @@ class AggregateReduceGroupFunction( aggregateMapping.foreach { case (after, previous) => { val agg = aggregates(previous) - val accumulator = agg.merge(accumulatorList(previous)) - val result = agg.getValue(accumulator) + val result = agg.getValue(accumulatorList(previous).get(0)) output.setField(after, result) } } http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala 
---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala index 47fa0f1..88cd19f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala @@ -45,17 +45,24 @@ class DataSetSessionWindowAggregateCombineGroupFunction( extends RichGroupCombineFunction[Row, Row] with ResultTypeQueryable[Row] { private var aggregateBuffer: Row = _ - private var accumStartPos: Int = groupingKeys.length - private var rowTimeFieldPos = accumStartPos + aggregates.length - private val maxMergeLen = 16 - val accumulatorList = Array.fill(aggregates.length) { - new JArrayList[Accumulator]() + private val accumStartPos: Int = groupingKeys.length + private val rowTimeFieldPos = accumStartPos + aggregates.length + + val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { + new JArrayList[Accumulator](2) } override def open(config: Configuration) { Preconditions.checkNotNull(aggregates) Preconditions.checkNotNull(groupingKeys) aggregateBuffer = new Row(rowTimeFieldPos + 2) + + // init lists with two empty accumulators + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).add(accumulator) + accumulatorList(i).add(accumulator) + } } /** @@ -72,15 +79,17 @@ class DataSetSessionWindowAggregateCombineGroupFunction( var windowStart: java.lang.Long = null var windowEnd: java.lang.Long = null var currentRowTime: java.lang.Long = null - accumulatorList.foreach(_.clear()) - val iterator 
= records.iterator() + // reset first accumulator in merge list + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } + val iterator = records.iterator() - var count: Int = 0 while (iterator.hasNext) { val record = iterator.next() - count += 1 currentRowTime = record.getField(rowTimeFieldPos).asInstanceOf[Long] // initial traversal or opening a new window if (windowEnd == null || (windowEnd != null && (currentRowTime > windowEnd))) { @@ -90,9 +99,11 @@ class DataSetSessionWindowAggregateCombineGroupFunction( // emit the current window's merged data doCollect(out, accumulatorList, windowStart, windowEnd) - // clear the accumulator list for all aggregate - accumulatorList.foreach(_.clear()) - count = 0 + // reset first value of accumulator list + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } } else { // set group keys to aggregateBuffer. 
for (i <- groupingKeys.indices) { @@ -103,21 +114,14 @@ class DataSetSessionWindowAggregateCombineGroupFunction( windowStart = record.getField(rowTimeFieldPos).asInstanceOf[Long] } - // collect the accumulators for each aggregate for (i <- aggregates.indices) { - accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator]) - } - - // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one - // accumulator - if (count > maxMergeLen) { - count = 0 - for (i <- aggregates.indices) { - val agg = aggregates(i) - val accumulator = agg.merge(accumulatorList(i)) - accumulatorList(i).clear() - accumulatorList(i).add(accumulator) - } + // insert received accumulator into acc list + val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator] + accumulatorList(i).set(1, newAcc) + // merge acc list + val retAcc = aggregates(i).merge(accumulatorList(i)) + // insert result into acc list + accumulatorList(i).set(0, retAcc) } // the current rowtime is the last rowtime of the next calculation. @@ -146,7 +150,7 @@ class DataSetSessionWindowAggregateCombineGroupFunction( // merge the accumulators into one accumulator for (i <- aggregates.indices) { - aggregateBuffer.setField(accumStartPos + i, aggregates(i).merge(accumulatorList(i))) + aggregateBuffer.setField(accumStartPos + i, accumulatorList(i).get(0)) } // intermediate Row WindowStartPos is rowtime pos. 
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala index 1570671..ebef211 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala @@ -64,13 +64,13 @@ class DataSetSessionWindowAggregateReduceGroupFunction( private var aggregateBuffer: Row = _ private var output: Row = _ private var collector: TimeWindowPropertyCollector = _ - private var accumStartPos: Int = groupKeysMapping.length - private var intermediateRowArity: Int = accumStartPos + aggregates.length + 2 - private var intermediateRowWindowStartPos = intermediateRowArity - 2 - private var intermediateRowWindowEndPos = intermediateRowArity - 1 - private val maxMergeLen = 16 - val accumulatorList = Array.fill(aggregates.length) { - new JArrayList[Accumulator]() + private val accumStartPos: Int = groupKeysMapping.length + private val intermediateRowArity: Int = accumStartPos + aggregates.length + 2 + private val intermediateRowWindowStartPos = intermediateRowArity - 2 + private val intermediateRowWindowEndPos = intermediateRowArity - 1 + + val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { + new JArrayList[Accumulator](2) } override def open(config: Configuration) { @@ -79,6 +79,13 @@ class DataSetSessionWindowAggregateReduceGroupFunction( 
aggregateBuffer = new Row(intermediateRowArity) output = new Row(finalRowArity) collector = new TimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos) + + // init lists with two empty accumulators + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).add(accumulator) + accumulatorList(i).add(accumulator) + } } /** @@ -96,14 +103,17 @@ class DataSetSessionWindowAggregateReduceGroupFunction( var windowStart: java.lang.Long = null var windowEnd: java.lang.Long = null var currentRowTime: java.lang.Long = null - accumulatorList.foreach(_.clear()) + + // reset first accumulator in merge list + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } val iterator = records.iterator() - var count: Int = 0 while (iterator.hasNext) { val record = iterator.next() - count += 1 currentRowTime = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long] // initial traversal or opening a new window if (null == windowEnd || @@ -114,9 +124,11 @@ class DataSetSessionWindowAggregateReduceGroupFunction( // evaluate and emit the current window's result. doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd) - // clear the accumulator list for all aggregate - accumulatorList.foreach(_.clear()) - count = 0 + // reset first accumulator in list + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } } else { // set group keys value to final output. 
groupKeysMapping.foreach { @@ -128,21 +140,14 @@ class DataSetSessionWindowAggregateReduceGroupFunction( windowStart = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long] } - // collect the accumulators for each aggregate for (i <- aggregates.indices) { - accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator]) - } - - // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one - // accumulator - if (count > maxMergeLen) { - count = 0 - for (i <- aggregates.indices) { - val agg = aggregates(i) - val accumulator = agg.merge(accumulatorList(i)) - accumulatorList(i).clear() - accumulatorList(i).add(accumulator) - } + // insert received accumulator into acc list + val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator] + accumulatorList(i).set(1, newAcc) + // merge acc list + val retAcc = aggregates(i).merge(accumulatorList(i)) + // insert result into acc list + accumulatorList(i).set(0, retAcc) } windowEnd = if (isInputCombined) { @@ -178,8 +183,7 @@ class DataSetSessionWindowAggregateReduceGroupFunction( aggregateMapping.foreach { case (after, previous) => val agg = aggregates(previous) - val accum = agg.merge(accumulatorList(previous)) - output.setField(after, agg.getValue(accum)) + output.setField(after, agg.getValue(accumulatorList(previous).get(0))) } // adds TimeWindow properties to output then emit output http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala index b722330..85df1d8 
100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala @@ -51,9 +51,9 @@ class DataSetTumbleCountWindowAggReduceGroupFunction( private var output: Row = _ private val accumStartPos: Int = groupKeysMapping.length private val intermediateRowArity: Int = accumStartPos + aggregates.length + 1 - private val maxMergeLen = 16 - val accumulatorList = Array.fill(aggregates.length) { - new JArrayList[Accumulator]() + + val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { + new JArrayList[Accumulator](2) } override def open(config: Configuration) { @@ -61,37 +61,41 @@ class DataSetTumbleCountWindowAggReduceGroupFunction( Preconditions.checkNotNull(groupKeysMapping) aggregateBuffer = new Row(intermediateRowArity) output = new Row(finalRowArity) + + // init lists with two empty accumulators + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).add(accumulator) + accumulatorList(i).add(accumulator) + } } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { var count: Long = 0 - accumulatorList.foreach(_.clear()) - val iterator = records.iterator() while (iterator.hasNext) { - val record = iterator.next() if (count == 0) { - // clear the accumulator list for all aggregate - accumulatorList.foreach(_.clear()) + // reset first accumulator + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } } - // collect the accumulators for each aggregate - for (i <- aggregates.indices) { - accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator]) - } + val record = iterator.next() count += 1 - // for every maxMergeLen accumulators, we merge them 
into one - if (count % maxMergeLen == 0) { - for (i <- aggregates.indices) { - val agg = aggregates(i) - val accumulator = agg.merge(accumulatorList(i)) - accumulatorList(i).clear() - accumulatorList(i).add(accumulator) - } + for (i <- aggregates.indices) { + // insert received accumulator into acc list + val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator] + accumulatorList(i).set(1, newAcc) + // merge acc list + val retAcc = aggregates(i).merge(accumulatorList(i)) + // insert result into acc list + accumulatorList(i).set(0, retAcc) } if (windowSize == count) { @@ -105,8 +109,7 @@ class DataSetTumbleCountWindowAggReduceGroupFunction( aggregateMapping.foreach { case (after, previous) => val agg = aggregates(previous) - val accumulator = agg.merge(accumulatorList(previous)) - output.setField(after, agg.getValue(accumulator)) + output.setField(after, agg.getValue(accumulatorList(previous).get(0))) } // emit the output http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala index d507a58..df8bed9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala @@ -18,7 +18,6 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable -import java.util.{ArrayList => JArrayList} import 
org.apache.flink.api.common.functions.CombineFunction import org.apache.flink.table.functions.{Accumulator, AggregateFunction} @@ -68,38 +67,33 @@ class DataSetTumbleTimeWindowAggReduceCombineFunction( override def combine(records: Iterable[Row]): Row = { var last: Row = null - accumulatorList.foreach(_.clear()) - val iterator = records.iterator() - var count: Int = 0 + // reset first accumulator in merge list + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } + while (iterator.hasNext) { val record = iterator.next() - count += 1 - // per each aggregator, collect its accumulators to a list + for (i <- aggregates.indices) { - accumulatorList(i).add(record.getField(groupKeysMapping.length + i) - .asInstanceOf[Accumulator]) - } - // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one - // accumulator - if (count > maxMergeLen) { - count = 0 - for (i <- aggregates.indices) { - val agg = aggregates(i) - val accumulator = agg.merge(accumulatorList(i)) - accumulatorList(i).clear() - accumulatorList(i).add(accumulator) - } + // insert received accumulator into acc list + val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator] + accumulatorList(i).set(1, newAcc) + // merge acc list + val retAcc = aggregates(i).merge(accumulatorList(i)) + // insert result into acc list + accumulatorList(i).set(0, retAcc) } + last = record } - // per each aggregator, merge list of accumulators into one and save the result to the - // intermediate aggregate buffer + // set the partial merged result to the aggregateBuffer for (i <- aggregates.indices) { - val agg = aggregates(i) - aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i))) + aggregateBuffer.setField(groupKeysMapping.length + i, accumulatorList(i).get(0)) } // set group keys to aggregateBuffer. 
http://git-wip-us.apache.org/repos/asf/flink/blob/14fab4c4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala index 63d2aeb..7ce0bf1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala @@ -57,9 +57,10 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( private val accumStartPos: Int = groupKeysMapping.length private val rowtimePos: Int = accumStartPos + aggregates.length private val intermediateRowArity: Int = rowtimePos + 1 - protected val maxMergeLen = 16 - val accumulatorList = Array.fill(aggregates.length) { - new JArrayList[Accumulator]() + + + val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { + new JArrayList[Accumulator](2) } override def open(config: Configuration) { @@ -68,34 +69,39 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( aggregateBuffer = new Row(intermediateRowArity) output = new Row(finalRowArity) collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) + + // init lists with two empty accumulators + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).add(accumulator) + accumulatorList(i).add(accumulator) + } } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { var last: Row = null - accumulatorList.foreach(_.clear()) - val iterator = 
records.iterator() - var count: Int = 0 + // reset first accumulator in merge list + for (i <- aggregates.indices) { + val accumulator = aggregates(i).createAccumulator() + accumulatorList(i).set(0, accumulator) + } + while (iterator.hasNext) { val record = iterator.next() - count += 1 - // per each aggregator, collect its accumulators to a list + for (i <- aggregates.indices) { - accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator]) - } - // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one - // accumulator - if (count > maxMergeLen) { - count = 0 - for (i <- aggregates.indices) { - val agg = aggregates(i) - val accumulator = agg.merge(accumulatorList(i)) - accumulatorList(i).clear() - accumulatorList(i).add(accumulator) - } + // insert received accumulator into acc list + val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator] + accumulatorList(i).set(1, newAcc) + // merge acc list + val retAcc = aggregates(i).merge(accumulatorList(i)) + // insert result into acc list + accumulatorList(i).set(0, retAcc) } + last = record } @@ -109,8 +115,7 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( aggregateMapping.foreach { case (after, previous) => { val agg = aggregates(previous) - val accumulator = agg.merge(accumulatorList(previous)) - val result = agg.getValue(accumulator) + val result = agg.getValue(accumulatorList(previous).get(0)) output.setField(after, result) } }
