[FLINK-4937] [table] Add incremental group window aggregation for streaming Table API.
This closes #2792. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/74e0971a Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/74e0971a Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/74e0971a Branch: refs/heads/master Commit: 74e0971a5511c511feb94bee7a0ce39eb9951b62 Parents: 06d252e Author: Jincheng Sun <[email protected]> Authored: Sat Nov 12 18:55:07 2016 +0800 Committer: Fabian Hueske <[email protected]> Committed: Wed Nov 23 18:35:44 2016 +0100 ---------------------------------------------------------------------- .../plan/nodes/dataset/DataSetAggregate.scala | 18 +- .../nodes/datastream/DataStreamAggregate.scala | 228 ++++++----- .../AggregateAllTimeWindowFunction.scala | 7 +- .../aggregate/AggregateAllWindowFunction.scala | 7 +- .../aggregate/AggregateMapFunction.scala | 3 +- .../AggregateReduceCombineFunction.scala | 54 +-- .../AggregateReduceGroupFunction.scala | 4 +- .../aggregate/AggregateTimeWindowFunction.scala | 14 +- .../table/runtime/aggregate/AggregateUtil.scala | 392 ++++++++++++++++--- .../aggregate/AggregateWindowFunction.scala | 12 +- ...rementalAggregateAllTimeWindowFunction.scala | 68 ++++ .../IncrementalAggregateAllWindowFunction.scala | 79 ++++ .../IncrementalAggregateReduceFunction.scala | 63 +++ ...IncrementalAggregateTimeWindowFunction.scala | 69 ++++ .../IncrementalAggregateWindowFunction.scala | 81 ++++ .../scala/stream/table/AggregationsITCase.scala | 14 +- 16 files changed, 848 insertions(+), 265 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala index c73d781..e85ade0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala @@ -85,17 +85,22 @@ class DataSetAggregate( } override def translateToPlan( - tableEnv: BatchTableEnvironment, - expectedType: Option[TypeInformation[Any]]): DataSet[Any] = { + tableEnv: BatchTableEnvironment, + expectedType: Option[TypeInformation[Any]]): DataSet[Any] = { val config = tableEnv.getConfig val groupingKeys = grouping.indices.toArray - // add grouping fields, position keys in the input, and input type - val aggregateResult = AggregateUtil.createOperatorFunctionsForAggregates( + + val mapFunction = AggregateUtil.createPrepareMapFunction( + namedAggregates, + grouping, + inputType) + + val groupReduceFunction = AggregateUtil.createAggregateGroupReduceFunction( namedAggregates, inputType, - getRowType, + rowRelDataType, grouping) val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan( @@ -111,10 +116,9 @@ class DataSetAggregate( val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil) val prepareOpName = s"prepare select: ($aggString)" val mappedInput = inputDS - .map(aggregateResult._1) + .map(mapFunction) .name(prepareOpName) - val groupReduceFunction = aggregateResult._2 val rowTypeInfo = new RowTypeInfo(fieldTypes) val result = { http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala index b4ae3ab..c7d5131 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala @@ -22,7 +22,6 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} -import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.tuple.Tuple import org.apache.flink.api.table.FlinkRelBuilder.NamedWindowProperty @@ -31,12 +30,11 @@ import org.apache.flink.api.table.plan.logical._ import org.apache.flink.api.table.plan.nodes.FlinkAggregate import org.apache.flink.api.table.plan.nodes.datastream.DataStreamAggregate._ import org.apache.flink.api.table.runtime.aggregate.AggregateUtil._ -import org.apache.flink.api.table.runtime.aggregate._ +import org.apache.flink.api.table.runtime.aggregate.{Aggregate, _} import org.apache.flink.api.table.typeutils.TypeCheckUtils.isTimeInterval import org.apache.flink.api.table.typeutils.{RowIntervalTypeInfo, RowTypeInfo, TimeIntervalTypeInfo, TypeConverter} -import org.apache.flink.api.table.{TableException, FlinkTypeFactory, Row, StreamTableEnvironment} +import org.apache.flink.api.table.{FlinkTypeFactory, Row, StreamTableEnvironment} import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream} -import org.apache.flink.streaming.api.functions.windowing.{WindowFunction, AllWindowFunction} import org.apache.flink.streaming.api.windowing.assigners._ import org.apache.flink.streaming.api.windowing.time.Time import 
org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} @@ -103,30 +101,24 @@ class DataStreamAggregate( } override def translateToPlan( - tableEnv: StreamTableEnvironment, - expectedType: Option[TypeInformation[Any]]) - : DataStream[Any] = { - - val config = tableEnv.getConfig + tableEnv: StreamTableEnvironment, + expectedType: Option[TypeInformation[Any]]): DataStream[Any] = { + val config = tableEnv.getConfig val groupingKeys = grouping.indices.toArray - // add grouping fields, position keys in the input, and input type - val aggregateResult = AggregateUtil.createOperatorFunctionsForAggregates( - namedAggregates, - inputType, - getRowType, - grouping) - val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan( tableEnv, // tell the input operator that this operator currently only supports Rows as input Some(TypeConverter.DEFAULT_ROW_TYPE)) // get the output types - val fieldTypes: Array[TypeInformation[_]] = getRowType.getFieldList.asScala + val fieldTypes: Array[TypeInformation[_]] = + getRowType.getFieldList.asScala .map(field => FlinkTypeFactory.toTypeInfo(field.getType)) .toArray + val rowTypeInfo = new RowTypeInfo(fieldTypes) + val aggString = aggregationToString( inputType, grouping, @@ -135,50 +127,118 @@ class DataStreamAggregate( namedProperties) val prepareOpName = s"prepare select: ($aggString)" - val mappedInput = inputDS - .map(aggregateResult._1) - .name(prepareOpName) - - val groupReduceFunction = aggregateResult._2 - val rowTypeInfo = new RowTypeInfo(fieldTypes) + val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + + s"window: ($window), " + + s"select: ($aggString)" + val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" - val result = { - // grouped / keyed aggregation - if (groupingKeys.length > 0) { - val aggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + - s"window: ($window), " + - s"select: ($aggString)" - val aggregateFunction = - 
createWindowAggregationFunction(window, namedProperties, groupReduceFunction) - - val keyedStream = mappedInput.keyBy(groupingKeys: _*) - - val windowedStream = createKeyedWindowedStream(window, keyedStream) - .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] - - windowedStream - .apply(aggregateFunction) + val mapFunction = AggregateUtil.createPrepareMapFunction( + namedAggregates, + grouping, + inputType) + + val mappedInput = inputDS.map(mapFunction).name(prepareOpName) + + val result: DataStream[Any] = { + // check whether all aggregates support partial aggregate + if (AggregateUtil.doAllSupportPartialAggregation( + namedAggregates.map(_.getKey), + inputType, + grouping.length)) { + // do Incremental Aggregation + val reduceFunction = AggregateUtil.createIncrementalAggregateReduceFunction( + namedAggregates, + inputType, + getRowType, + grouping) + // grouped / keyed aggregation + if (groupingKeys.length > 0) { + val windowFunction = AggregateUtil.createWindowIncrementalAggregationFunction( + window, + namedAggregates, + inputType, + rowRelDataType, + grouping, + namedProperties) + + val keyedStream = mappedInput.keyBy(groupingKeys: _*) + val windowedStream = + createKeyedWindowedStream(window, keyedStream) + .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] + + windowedStream + .apply(reduceFunction, windowFunction) + .returns(rowTypeInfo) + .name(keyedAggOpName) + .asInstanceOf[DataStream[Any]] + } + // global / non-keyed aggregation + else { + val windowFunction = AggregateUtil.createAllWindowIncrementalAggregationFunction( + window, + namedAggregates, + inputType, + rowRelDataType, + grouping, + namedProperties) + + val windowedStream = + createNonKeyedWindowedStream(window, mappedInput) + .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] + + windowedStream + .apply(reduceFunction, windowFunction) .returns(rowTypeInfo) - .name(aggOpName) + .name(nonKeyedAggOpName) .asInstanceOf[DataStream[Any]] + } } - // global / non-keyed 
aggregation else { - val aggOpName = s"window: ($window), select: ($aggString)" - val aggregateFunction = - createAllWindowAggregationFunction(window, namedProperties, groupReduceFunction) - - val windowedStream = createNonKeyedWindowedStream(window, mappedInput) - .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] - - windowedStream - .apply(aggregateFunction) + // do non-Incremental Aggregation + // grouped / keyed aggregation + if (groupingKeys.length > 0) { + + val windowFunction = AggregateUtil.createWindowAggregationFunction( + window, + namedAggregates, + inputType, + rowRelDataType, + grouping, + namedProperties) + + val keyedStream = mappedInput.keyBy(groupingKeys: _*) + val windowedStream = + createKeyedWindowedStream(window, keyedStream) + .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] + + windowedStream + .apply(windowFunction) .returns(rowTypeInfo) - .name(aggOpName) + .name(keyedAggOpName) .asInstanceOf[DataStream[Any]] + } + // global / non-keyed aggregation + else { + val windowFunction = AggregateUtil.createAllWindowAggregationFunction( + window, + namedAggregates, + inputType, + rowRelDataType, + grouping, + namedProperties) + + val windowedStream = + createNonKeyedWindowedStream(window, mappedInput) + .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] + + windowedStream + .apply(windowFunction) + .returns(rowTypeInfo) + .name(nonKeyedAggOpName) + .asInstanceOf[DataStream[Any]] + } } } - // if the expected type is not a Row, inject a mapper to convert to the expected type expectedType match { case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] => @@ -196,72 +256,8 @@ class DataStreamAggregate( } } } - object DataStreamAggregate { - private def createAllWindowAggregationFunction( - window: LogicalWindow, - properties: Seq[NamedWindowProperty], - aggFunction: RichGroupReduceFunction[Row, Row]) - : AllWindowFunction[Row, Row, DataStreamWindow] = { - - if (isTimeWindow(window)) { - val (startPos, endPos) = 
computeWindowStartEndPropertyPos(properties) - new AggregateAllTimeWindowFunction(aggFunction, startPos, endPos) - .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]] - } else { - new AggregateAllWindowFunction(aggFunction) - } - - } - - private def createWindowAggregationFunction( - window: LogicalWindow, - properties: Seq[NamedWindowProperty], - aggFunction: RichGroupReduceFunction[Row, Row]) - : WindowFunction[Row, Row, Tuple, DataStreamWindow] = { - - if (isTimeWindow(window)) { - val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) - new AggregateTimeWindowFunction(aggFunction, startPos, endPos) - .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]] - } else { - new AggregateWindowFunction(aggFunction) - } - - } - - private def isTimeWindow(window: LogicalWindow) = { - window match { - case ProcessingTimeTumblingGroupWindow(_, size) => isTimeInterval(size.resultType) - case ProcessingTimeSlidingGroupWindow(_, size, _) => isTimeInterval(size.resultType) - case ProcessingTimeSessionGroupWindow(_, _) => true - case EventTimeTumblingGroupWindow(_, _, size) => isTimeInterval(size.resultType) - case EventTimeSlidingGroupWindow(_, _, size, _) => isTimeInterval(size.resultType) - case EventTimeSessionGroupWindow(_, _, _) => true - } - } - - def computeWindowStartEndPropertyPos(properties: Seq[NamedWindowProperty]) - : (Option[Int], Option[Int]) = { - - val propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) { - (p, x) => p match { - case NamedWindowProperty(name, prop) => - prop match { - case WindowStart(_) if x._1.isDefined => - throw new TableException("Duplicate WindowStart property encountered. This is a bug.") - case WindowStart(_) => - (Some(x._3), x._2, x._3 - 1) - case WindowEnd(_) if x._2.isDefined => - throw new TableException("Duplicate WindowEnd property encountered. 
This is a bug.") - case WindowEnd(_) => - (x._1, Some(x._3), x._3 - 1) - } - } - } - (propPos._1, propPos._2) - } private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) : WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala index ceadfe7..7ace2c5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala @@ -31,14 +31,13 @@ class AggregateAllTimeWindowFunction( groupReduceFunction: RichGroupReduceFunction[Row, Row], windowStartPos: Option[Int], windowEndPos: Option[Int]) - - extends RichAllWindowFunction[Row, Row, TimeWindow] { + extends AggregateAllWindowFunction[TimeWindow](groupReduceFunction) { private var collector: TimeWindowPropertyCollector = _ override def open(parameters: Configuration): Unit = { - groupReduceFunction.open(parameters) collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) + super.open(parameters) } override def apply(window: TimeWindow, input: Iterable[Row], out: Collector[Row]): Unit = { @@ -48,6 +47,6 @@ class AggregateAllTimeWindowFunction( collector.timeWindow = window // call wrapped reduce function with property collector - groupReduceFunction.reduce(input, collector) + super.apply(window, input, collector) } } 
http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala index 53ab948..4b045be 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala @@ -27,14 +27,15 @@ import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction import org.apache.flink.streaming.api.windowing.windows.Window import org.apache.flink.util.Collector -class AggregateAllWindowFunction(groupReduceFunction: RichGroupReduceFunction[Row, Row]) - extends RichAllWindowFunction[Row, Row, Window] { +class AggregateAllWindowFunction[W <: Window]( + groupReduceFunction: RichGroupReduceFunction[Row, Row]) + extends RichAllWindowFunction[Row, Row, W] { override def open(parameters: Configuration): Unit = { groupReduceFunction.open(parameters) } - override def apply(window: Window, input: Iterable[Row], out: Collector[Row]): Unit = { + override def apply(window: W, input: Iterable[Row], out: Collector[Row]): Unit = { groupReduceFunction.reduce(input, out) } } http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateMapFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateMapFunction.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateMapFunction.scala index d848d21..7559cec 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateMapFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateMapFunction.scala @@ -29,7 +29,8 @@ class AggregateMapFunction[IN, OUT]( private val aggFields: Array[Int], private val groupingKeys: Array[Int], @transient private val returnType: TypeInformation[OUT]) - extends RichMapFunction[IN, OUT] with ResultTypeQueryable[OUT] { + extends RichMapFunction[IN, OUT] + with ResultTypeQueryable[OUT] { private var output: Row = _ http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala index ca074cc..ebf0ca7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala @@ -45,53 +45,13 @@ class AggregateReduceCombineFunction( private val aggregateMapping: Array[(Int, Int)], private val intermediateRowArity: Int, private val finalRowArity: Int) - extends RichGroupReduceFunction[Row, Row] with CombineFunction[Row, Row] { - - private var aggregateBuffer: Row = _ - private var output: Row = _ - - override def open(config: Configuration): Unit = { - Preconditions.checkNotNull(aggregates) - 
Preconditions.checkNotNull(groupKeysMapping) - aggregateBuffer = new Row(intermediateRowArity) - output = new Row(finalRowArity) - } - - /** - * For grouped intermediate aggregate Rows, merge all of them into aggregate buffer, - * calculate aggregated values output by aggregate buffer, and set them into output - * Row based on the mapping relation between intermediate aggregate Row and output Row. - * - * @param records Grouped intermediate aggregate Rows iterator. - * @param out The collector to hand results to. - * - */ - override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { - - // Initiate intermediate aggregate value. - aggregates.foreach(_.initiate(aggregateBuffer)) - - // Merge intermediate aggregate value to buffer. - var last: Row = null - records.foreach((record) => { - aggregates.foreach(_.merge(record, aggregateBuffer)) - last = record - }) - - // Set group keys value to final output. - groupKeysMapping.foreach { - case (after, previous) => - output.setField(after, last.productElement(previous)) - } - - // Evaluate final aggregate value and set to output. 
- aggregateMapping.foreach { - case (after, previous) => - output.setField(after, aggregates(previous).evaluate(aggregateBuffer)) - } - - out.collect(output) - } + extends AggregateReduceGroupFunction( + aggregates, + groupKeysMapping, + aggregateMapping, + intermediateRowArity, + finalRowArity) + with CombineFunction[Row, Row] { /** * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer, http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceGroupFunction.scala index d81f3a1..8f096cc 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceGroupFunction.scala @@ -42,9 +42,9 @@ class AggregateReduceGroupFunction( private val aggregateMapping: Array[(Int, Int)], private val intermediateRowArity: Int, private val finalRowArity: Int) - extends RichGroupReduceFunction[Row, Row] { + extends RichGroupReduceFunction[Row, Row] { - private var aggregateBuffer: Row = _ + protected var aggregateBuffer: Row = _ private var output: Row = _ override def open(config: Configuration) { http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTimeWindowFunction.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTimeWindowFunction.scala index 80f52ca..9b7ea0b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTimeWindowFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTimeWindowFunction.scala @@ -32,26 +32,26 @@ class AggregateTimeWindowFunction( groupReduceFunction: RichGroupReduceFunction[Row, Row], windowStartPos: Option[Int], windowEndPos: Option[Int]) - extends RichWindowFunction[Row, Row, Tuple, TimeWindow] { + extends AggregateWindowFunction[TimeWindow](groupReduceFunction) { private var collector: TimeWindowPropertyCollector = _ override def open(parameters: Configuration): Unit = { - groupReduceFunction.open(parameters) collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) + super.open(parameters) } override def apply( - key: Tuple, - window: TimeWindow, - input: Iterable[Row], - out: Collector[Row]) : Unit = { + key: Tuple, + window: TimeWindow, + input: Iterable[Row], + out: Collector[Row]): Unit = { // set collector and window collector.wrappedCollector = out collector.timeWindow = window // call wrapped reduce function with property collector - groupReduceFunction.reduce(input, collector) + super.apply(key, window, input, collector) } } http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala 
index 903cc07..4428963 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala @@ -27,9 +27,15 @@ import org.apache.calcite.sql.`type`.{SqlTypeFactoryImpl, SqlTypeName} import org.apache.calcite.sql.fun._ import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction} import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.tuple.Tuple +import org.apache.flink.api.table.FlinkRelBuilder.NamedWindowProperty +import org.apache.flink.api.table.expressions.{WindowEnd, WindowStart} +import org.apache.flink.api.table.plan.logical._ import org.apache.flink.api.table.typeutils.RowTypeInfo +import org.apache.flink.api.table.typeutils.TypeCheckUtils._ import org.apache.flink.api.table.{FlinkTypeFactory, Row, TableException} - +import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction} +import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer @@ -39,66 +45,76 @@ object AggregateUtil { type JavaList[T] = java.util.List[T] /** - * Create Flink operator functions for aggregates. It includes 2 implementations of Flink - * operator functions: - * [[org.apache.flink.api.common.functions.MapFunction]] and - * [[org.apache.flink.api.common.functions.GroupReduceFunction]](if it's partial aggregate, - * should also implement [[org.apache.flink.api.common.functions.CombineFunction]] as well). 
- * The output of [[org.apache.flink.api.common.functions.MapFunction]] contains the - * intermediate aggregate values of all aggregate function, it's stored in Row by the following - * format: - * - * {{{ - * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5 - * | | - * v v - * +---------+---------+--------+--------+--------+--------+ - * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 | - * +---------+---------+--------+--------+--------+--------+ - * ^ - * | - * sum(y) aggOffsetInRow = 4 - * }}} - * - */ - def createOperatorFunctionsForAggregates( - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - inputType: RelDataType, - outputType: RelDataType, - groupings: Array[Int]) - : (MapFunction[Any, Row], RichGroupReduceFunction[Row, Row]) = { - - val aggregateFunctionsAndFieldIndexes = - transformToAggregateFunctions(namedAggregates.map(_.getKey), inputType, groupings.length) - // store the aggregate fields of each aggregate function, by the same order of aggregates. - val aggFieldIndexes = aggregateFunctionsAndFieldIndexes._1 - val aggregates = aggregateFunctionsAndFieldIndexes._2 + * Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates. 
+ * The function returns intermediate aggregate values of all aggregate function which are + * organized by the following format: + * + * {{{ + * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5 + * | | + * v v + * +---------+---------+--------+--------+--------+--------+ + * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 | + * +---------+---------+--------+--------+--------+--------+ + * ^ + * | + * sum(y) aggOffsetInRow = 4 + * }}} + * + */ + private[flink] def createPrepareMapFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + groupings: Array[Int], + inputType: RelDataType): MapFunction[Any, Row] = { + + val (aggFieldIndexes,aggregates) = transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputType, + groupings.length) val mapReturnType: RowTypeInfo = createAggregateBufferDataType(groupings, aggregates, inputType) val mapFunction = new AggregateMapFunction[Row, Row]( - aggregates, aggFieldIndexes, groupings, - mapReturnType.asInstanceOf[RowTypeInfo]).asInstanceOf[MapFunction[Any, Row]] - - // the mapping relation between field index of intermediate aggregate Row and output Row. - val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings) - - // the mapping relation between aggregate function index in list and its corresponding - // field index in output Row. 
- val aggOffsetMapping = getAggregateMapping(namedAggregates, outputType) - - if (groupingOffsetMapping.length != groupings.length || - aggOffsetMapping.length != namedAggregates.length) { - throw new TableException("Could not find output field in input data type " + - "or aggregate functions.") - } - - val allPartialAggregate = aggregates.map(_.supportPartial).forall(x => x) + aggregates, + aggFieldIndexes, + groupings, + mapReturnType.asInstanceOf[RowTypeInfo]).asInstanceOf[MapFunction[Any, Row]] - val intermediateRowArity = groupings.length + aggregates.map(_.intermediateDataType.length).sum + mapFunction + } - val reduceGroupFunction = + /** + * Create a [[org.apache.flink.api.common.functions.GroupReduceFunction]] to compute aggregates. + * If all aggregates support partial aggregation, the + * [[org.apache.flink.api.common.functions.GroupReduceFunction]] implements + * [[org.apache.flink.api.common.functions.CombineFunction]] as well. + * + */ + private[flink] def createAggregateGroupReduceFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int]): RichGroupReduceFunction[Row, Row] = { + + val aggregates = transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputType, + groupings.length)._2 + + val (groupingOffsetMapping, aggOffsetMapping) = + getGroupingOffsetAndAggOffsetMapping( + namedAggregates, + inputType, + outputType, + groupings) + + val allPartialAggregate: Boolean = aggregates.forall(_.supportPartial) + + val intermediateRowArity = groupings.length + + aggregates.map(_.intermediateDataType.length).sum + + val groupReduceFunction = if (allPartialAggregate) { new AggregateReduceCombineFunction( aggregates, @@ -115,14 +131,257 @@ object AggregateUtil { intermediateRowArity, outputType.getFieldCount) } + groupReduceFunction + } + + /** + * Create a [[org.apache.flink.api.common.functions.ReduceFunction]] for incremental window + * aggregation. 
+ * + */ + private[flink] def createIncrementalAggregateReduceFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int]): IncrementalAggregateReduceFunction = { + + val aggregates = transformToAggregateFunctions( + namedAggregates.map(_.getKey),inputType,groupings.length)._2 + + val groupingOffsetMapping = + getGroupingOffsetAndAggOffsetMapping( + namedAggregates, + inputType, + outputType, + groupings)._1 + + val intermediateRowArity = groupings.length + aggregates.map(_.intermediateDataType.length).sum + val reduceFunction = new IncrementalAggregateReduceFunction( + aggregates, + groupingOffsetMapping, + intermediateRowArity) + reduceFunction + } + + /** + * Create an [[AllWindowFunction]] to compute non-partitioned group window aggregates. + */ + private[flink] def createAllWindowAggregationFunction( + window: LogicalWindow, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int], + properties: Seq[NamedWindowProperty]) + : AllWindowFunction[Row, Row, DataStreamWindow] = { + + val aggFunction = + createAggregateGroupReduceFunction( + namedAggregates, + inputType, + outputType, + groupings) + + if (isTimeWindow(window)) { + val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + new AggregateAllTimeWindowFunction(aggFunction, startPos, endPos) + .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]] + } else { + new AggregateAllWindowFunction(aggFunction) + } + } + + /** + * Create a [[WindowFunction]] to compute partitioned group window aggregates. 
+ * + */ + private[flink] def createWindowAggregationFunction( + window: LogicalWindow, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int], + properties: Seq[NamedWindowProperty]) + : WindowFunction[Row, Row, Tuple, DataStreamWindow] = { + + val aggFunction = + createAggregateGroupReduceFunction( + namedAggregates, + inputType, + outputType, + groupings) + + if (isTimeWindow(window)) { + val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + new AggregateTimeWindowFunction(aggFunction, startPos, endPos) + .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]] + } else { + new AggregateWindowFunction(aggFunction) + } + } - (mapFunction, reduceGroupFunction) + /** + * Create an [[AllWindowFunction]] to finalize incrementally pre-computed non-partitioned + * window aggregates. + */ + private[flink] def createAllWindowIncrementalAggregationFunction( + window: LogicalWindow, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int], + properties: Seq[NamedWindowProperty]): AllWindowFunction[Row, Row, DataStreamWindow] = { + + val aggregates = transformToAggregateFunctions( + namedAggregates.map(_.getKey),inputType,groupings.length)._2 + + val (groupingOffsetMapping, aggOffsetMapping) = + getGroupingOffsetAndAggOffsetMapping( + namedAggregates, + inputType, + outputType, + groupings) + + val finalRowArity = outputType.getFieldCount + + if (isTimeWindow(window)) { + val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + new IncrementalAggregateAllTimeWindowFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + finalRowArity, + startPos, + endPos) + .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]] + } else { + new IncrementalAggregateAllWindowFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + finalRowArity) + } + } + + 
/** + * Create a [[WindowFunction]] to finalize incrementally pre-computed window aggregates. + */ + private[flink] def createWindowIncrementalAggregationFunction( + window: LogicalWindow, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int], + properties: Seq[NamedWindowProperty]): WindowFunction[Row, Row, Tuple, DataStreamWindow] = { + + val aggregates = transformToAggregateFunctions( + namedAggregates.map(_.getKey),inputType,groupings.length)._2 + + val (groupingOffsetMapping, aggOffsetMapping) = + getGroupingOffsetAndAggOffsetMapping( + namedAggregates, + inputType, + outputType, + groupings) + + val finalRowArity = outputType.getFieldCount + + if (isTimeWindow(window)) { + val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + new IncrementalAggregateTimeWindowFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + finalRowArity, + startPos, + endPos) + .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]] + } else { + new IncrementalAggregateWindowFunction( + aggregates, + groupingOffsetMapping, + aggOffsetMapping, + finalRowArity) + } + } + + /** + * Return true if all aggregates can be partially computed. False otherwise. + */ + private[flink] def doAllSupportPartialAggregation( + aggregateCalls: Seq[AggregateCall], + inputType: RelDataType, + groupKeysCount: Int): Boolean = { + transformToAggregateFunctions( + aggregateCalls, + inputType, + groupKeysCount)._2.forall(_.supportPartial) + } + + /** + * @return groupingOffsetMapping (mapping relation between field index of intermediate + * aggregate Row and output Row.) + * and aggOffsetMapping (the mapping relation between aggregate function index in list + * and its corresponding field index in output Row.) 
+ */ + private def getGroupingOffsetAndAggOffsetMapping( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + outputType: RelDataType, + groupings: Array[Int]): (Array[(Int, Int)], Array[(Int, Int)]) = { + + // the mapping relation between field index of intermediate aggregate Row and output Row. + val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings) + + // the mapping relation between aggregate function index in list and its corresponding + // field index in output Row. + val aggOffsetMapping = getAggregateMapping(namedAggregates, outputType) + + if (groupingOffsetMapping.length != groupings.length || + aggOffsetMapping.length != namedAggregates.length) { + throw new TableException( + "Could not find output field in input data type " + + "or aggregate functions.") + } + (groupingOffsetMapping, aggOffsetMapping) + } + + private def isTimeWindow(window: LogicalWindow) = { + window match { + case ProcessingTimeTumblingGroupWindow(_, size) => isTimeInterval(size.resultType) + case ProcessingTimeSlidingGroupWindow(_, size, _) => isTimeInterval(size.resultType) + case ProcessingTimeSessionGroupWindow(_, _) => true + case EventTimeTumblingGroupWindow(_, _, size) => isTimeInterval(size.resultType) + case EventTimeSlidingGroupWindow(_, _, size, _) => isTimeInterval(size.resultType) + case EventTimeSessionGroupWindow(_, _, _) => true + } + } + + private def computeWindowStartEndPropertyPos( + properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = { + + val propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) { + (p, x) => p match { + case NamedWindowProperty(name, prop) => + prop match { + case WindowStart(_) if x._1.isDefined => + throw new TableException("Duplicate WindowStart property encountered. 
This is a bug.") + case WindowStart(_) => + (Some(x._3), x._2, x._3 - 1) + case WindowEnd(_) if x._2.isDefined => + throw new TableException("Duplicate WindowEnd property encountered. This is a bug.") + case WindowEnd(_) => + (x._1, Some(x._3), x._3 - 1) + } + } + } + (propPos._1, propPos._2) } private def transformToAggregateFunctions( - aggregateCalls: Seq[AggregateCall], - inputType: RelDataType, - groupKeysCount: Int): (Array[Int], Array[Aggregate[_ <: Any]]) = { + aggregateCalls: Seq[AggregateCall], + inputType: RelDataType, + groupKeysCount: Int): (Array[Int], Array[Aggregate[_ <: Any]]) = { // store the aggregate fields of each aggregate function, by the same order of aggregates. val aggFieldIndexes = new Array[Int](aggregateCalls.size) @@ -253,9 +512,9 @@ object AggregateUtil { } private def createAggregateBufferDataType( - groupings: Array[Int], - aggregates: Array[Aggregate[_]], - inputType: RelDataType): RowTypeInfo = { + groupings: Array[Int], + aggregates: Array[Aggregate[_]], + inputType: RelDataType): RowTypeInfo = { // get the field data types of group keys. val groupingTypes: Seq[TypeInformation[_]] = groupings @@ -275,8 +534,9 @@ object AggregateUtil { } // Find the mapping between the index of aggregate list and aggregated value index in output Row. - private def getAggregateMapping(namedAggregates: Seq[CalcitePair[AggregateCall, String]], - outputType: RelDataType): Array[(Int, Int)] = { + private def getAggregateMapping( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + outputType: RelDataType): Array[(Int, Int)] = { // the mapping relation between aggregate function index in list and its corresponding // field index in output Row. @@ -298,8 +558,10 @@ object AggregateUtil { // Find the mapping between the index of group key in intermediate aggregate Row and its index // in output Row. 
- private def getGroupKeysMapping(inputDatType: RelDataType, - outputType: RelDataType, groupKeys: Array[Int]): Array[(Int, Int)] = { + private def getGroupKeysMapping( + inputDatType: RelDataType, + outputType: RelDataType, + groupKeys: Array[Int]): Array[(Int, Int)] = { // the mapping relation between field index of intermediate aggregate Row and output Row. var groupingOffsetMapping = ArrayBuffer[(Int, Int)]() http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateWindowFunction.scala index 180248f..6fd890d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateWindowFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateWindowFunction.scala @@ -28,18 +28,18 @@ import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction import org.apache.flink.streaming.api.windowing.windows.Window import org.apache.flink.util.Collector -class AggregateWindowFunction(groupReduceFunction: RichGroupReduceFunction[Row, Row]) - extends RichWindowFunction[Row, Row, Tuple, Window] { +class AggregateWindowFunction[W <: Window](groupReduceFunction: RichGroupReduceFunction[Row, Row]) + extends RichWindowFunction[Row, Row, Tuple, W] { override def open(parameters: Configuration): Unit = { groupReduceFunction.open(parameters) } override def apply( - key: Tuple, - window: Window, - input: Iterable[Row], - out: Collector[Row]) : Unit = { + key: Tuple, + window: W, + input: Iterable[Row], + out: Collector[Row]): Unit = { 
groupReduceFunction.reduce(input, out) } http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala new file mode 100644 index 0000000..85ad8e5 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.table.runtime.aggregate + +import java.lang.Iterable + +import org.apache.flink.api.table.Row +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.windowing.windows.{TimeWindow, Window} +import org.apache.flink.util.Collector +/** + * + * Computes the final aggregate value from incrementally computed aggregates. + * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + * @param aggregateMapping The index mapping between aggregate function list and aggregated value + * index in output Row. + * @param finalRowArity The arity of the final output row. + */ +class IncrementalAggregateAllTimeWindowFunction( + private val aggregates: Array[Aggregate[_ <: Any]], + private val groupKeysMapping: Array[(Int, Int)], + private val aggregateMapping: Array[(Int, Int)], + private val finalRowArity: Int, + private val windowStartPos: Option[Int], + private val windowEndPos: Option[Int]) + extends IncrementalAggregateAllWindowFunction[TimeWindow]( + aggregates, + groupKeysMapping, + aggregateMapping, + finalRowArity) { + + private var collector: TimeWindowPropertyCollector = _ + + override def open(parameters: Configuration): Unit = { + collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) + super.open(parameters) + } + + override def apply( + window: TimeWindow, + records: Iterable[Row], + out: Collector[Row]): Unit = { + + // set collector and window + collector.wrappedCollector = out + collector.timeWindow = window + + super.apply(window, records, collector) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala new file mode 100644 index 0000000..d3f871a --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.runtime.aggregate + +import java.lang.Iterable + +import org.apache.flink.api.table.Row +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction +import org.apache.flink.streaming.api.windowing.windows.Window +import org.apache.flink.util.{Collector, Preconditions} + +/** + * Computes the final aggregate value from incrementally computed aggregates. + * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. 
+ * @param aggregateMapping The index mapping between aggregate function list and aggregated value + * index in output Row. + * @param finalRowArity The arity of the final output row. + */ +class IncrementalAggregateAllWindowFunction[W <: Window]( + private val aggregates: Array[Aggregate[_ <: Any]], + private val groupKeysMapping: Array[(Int, Int)], + private val aggregateMapping: Array[(Int, Int)], + private val finalRowArity: Int) + extends RichAllWindowFunction[Row, Row, W] { + + private var output: Row = _ + + override def open(parameters: Configuration): Unit = { + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(groupKeysMapping) + output = new Row(finalRowArity) + } + + /** + * Calculate aggregated values output by aggregate buffer, and set them into output + * Row based on the mapping relation between intermediate aggregate data and output data. + */ + override def apply( + window: W, + records: Iterable[Row], + out: Collector[Row]): Unit = { + + val iterator = records.iterator + + if (iterator.hasNext) { + val record = iterator.next() + // Set group keys value to final output. + groupKeysMapping.foreach { + case (after, previous) => + output.setField(after, record.productElement(previous)) + } + // Evaluate final aggregate value and set to output. 
+ aggregateMapping.foreach { + case (after, previous) => + output.setField(after, aggregates(previous).evaluate(record)) + } + out.collect(output) + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala new file mode 100644 index 0000000..e2830da --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.runtime.aggregate + +import org.apache.flink.api.common.functions.ReduceFunction +import org.apache.flink.api.table.Row +import org.apache.flink.util.Preconditions + +/** + * Incrementally computes group window aggregates. 
+ * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + */ +class IncrementalAggregateReduceFunction( + private val aggregates: Array[Aggregate[_]], + private val groupKeysMapping: Array[(Int, Int)], + private val intermediateRowArity: Int) + extends ReduceFunction[Row] { + + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(groupKeysMapping) + + /** + * For Incremental intermediate aggregate Rows, merge value1 and value2 + * into aggregate buffer, return aggregate buffer. + * + * @param value1 The first value to be combined. + * @param value2 The second value to be combined. + * @return accumulatorRow A resulting row that combines two input values. + * + */ + override def reduce(value1: Row, value2: Row): Row = { + + // TODO: once FLINK-5105 is solved, we can avoid creating a new row for each invocation + // and directly merge value1 and value2. + val accumulatorRow = new Row(intermediateRowArity) + + // copy all fields of value1 into accumulatorRow + (0 until intermediateRowArity) + .foreach(i => accumulatorRow.setField(i, value1.productElement(i))) + // merge value2 to accumulatorRow + aggregates.foreach(_.merge(value2, accumulatorRow)) + + accumulatorRow + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala new file mode 100644 index 0000000..c880f87 --- /dev/null +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.runtime.aggregate + +import java.lang.Iterable + +import org.apache.flink.api.java.tuple.Tuple +import org.apache.flink.api.table.Row +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.windowing.windows.TimeWindow +import org.apache.flink.util.Collector + +/** + * Computes the final aggregate value from incrementally computed aggregates. + * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + * @param aggregateMapping The index mapping between aggregate function list and aggregated value + * index in output Row. + * @param finalRowArity The arity of the final output row. 
+ */ +class IncrementalAggregateTimeWindowFunction( + private val aggregates: Array[Aggregate[_ <: Any]], + private val groupKeysMapping: Array[(Int, Int)], + private val aggregateMapping: Array[(Int, Int)], + private val finalRowArity: Int, + private val windowStartPos: Option[Int], + private val windowEndPos: Option[Int]) + extends IncrementalAggregateWindowFunction[TimeWindow]( + aggregates, + groupKeysMapping, + aggregateMapping, finalRowArity) { + + private var collector: TimeWindowPropertyCollector = _ + + override def open(parameters: Configuration): Unit = { + collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) + super.open(parameters) + } + + override def apply( + key: Tuple, + window: TimeWindow, + records: Iterable[Row], + out: Collector[Row]): Unit = { + + // set collector and window + collector.wrappedCollector = out + collector.timeWindow = window + + super.apply(key, window, records, collector) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala new file mode 100644 index 0000000..81e6890 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.runtime.aggregate + +import java.lang.Iterable + +import org.apache.flink.api.java.tuple.Tuple +import org.apache.flink.api.table.Row +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction +import org.apache.flink.streaming.api.windowing.windows.Window +import org.apache.flink.util.{Collector, Preconditions} + +/** + * Computes the final aggregate value from incrementally computed aggregates. + * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + * @param aggregateMapping The index mapping between aggregate function list and aggregated value + * index in output Row. + * @param finalRowArity The arity of the final output row. 
+ */ +class IncrementalAggregateWindowFunction[W <: Window]( + private val aggregates: Array[Aggregate[_ <: Any]], + private val groupKeysMapping: Array[(Int, Int)], + private val aggregateMapping: Array[(Int, Int)], + private val finalRowArity: Int) + extends RichWindowFunction[Row, Row, Tuple, W] { + + private var output: Row = _ + + override def open(parameters: Configuration): Unit = { + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(groupKeysMapping) + output = new Row(finalRowArity) + } + + /** + * Calculate aggregated values output by aggregate buffer, and set them into output + * Row based on the mapping relation between intermediate aggregate data and output data. + */ + override def apply( + key: Tuple, + window: W, + records: Iterable[Row], + out: Collector[Row]): Unit = { + + val iterator = records.iterator + + if (iterator.hasNext) { + val record = iterator.next() + // Set group keys value to final output. + groupKeysMapping.foreach { + case (after, previous) => + output.setField(after, record.productElement(previous)) + } + // Evaluate final aggregate value and set to output. 
+ aggregateMapping.foreach { + case (after, previous) => + output.setField(after, aggregates(previous).evaluate(record)) + } + out.collect(output) + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/74e0971a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/AggregationsITCase.scala index 2ccbb38..0753484 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/AggregationsITCase.scala @@ -59,13 +59,13 @@ class AggregationsITCase extends StreamingMultipleProgramsTestBase { val windowedTable = table .groupBy('string) .window(Slide over 2.rows every 1.rows) - .select('string, 'int.count) + .select('string, 'int.count, 'int.avg) val results = windowedTable.toDataStream[Row] results.addSink(new StreamITCase.StringSink) env.execute() - val expected = Seq("Hello world,1", "Hello world,2", "Hello,1", "Hello,2", "Hi,1") + val expected = Seq("Hello world,1,3", "Hello world,2,3", "Hello,1,2", "Hello,2,2", "Hi,1,1") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -131,17 +131,17 @@ class AggregationsITCase extends StreamingMultipleProgramsTestBase { val windowedTable = table .groupBy('string) .window(Tumble over 5.milli on 'rowtime as 'w) - .select('string, 'int.count, 'w.start, 'w.end) + .select('string, 'int.count, 'int.avg, 'w.start, 'w.end) val results = windowedTable.toDataStream[Row] results.addSink(new StreamITCase.StringSink) env.execute() val expected = Seq( - "Hello world,1,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", - "Hello world,1,1970-01-01 
00:00:00.015,1970-01-01 00:00:00.02", - "Hello,2,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", - "Hi,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") + "Hello world,1,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", + "Hello world,1,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02", + "Hello,2,2,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", + "Hi,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") assertEquals(expected.sorted, StreamITCase.testResults.sorted) }
