http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index cd473ee..40468ad 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -25,8 +25,8 @@ import org.apache.calcite.sql.{SqlAggFunction, SqlKind} import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.sql.fun._ -import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction,RichGroupCombineFunction} -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.common.functions.{InvalidTypesException, MapFunction, RichGroupCombineFunction, RichGroupReduceFunction, AggregateFunction => ApiAggregateFunction} +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeHint, TypeInformation} import org.apache.flink.api.java.tuple.Tuple import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory} @@ -37,6 +37,9 @@ import org.apache.flink.table.typeutils.TypeCheckUtils._ import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction} import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} import org.apache.flink.table.api.{TableException, Types} +import org.apache.flink.table.functions.aggfunctions._ +import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} import org.apache.flink.types.Row @@ -54,15 +57,15 @@ object AggregateUtil { * organized by the following format: * * {{{ - * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5 - * | | - * v v - * +---------+---------+--------+--------+--------+--------+ - * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 | - * +---------+---------+--------+--------+--------+--------+ + * avg(x) count(z) + * | | + * v v + * +---------+---------+-----------------+------------------+------------------+ + * |groupKey1|groupKey2| AvgAccumulator | SumAccumulator | CountAccumulator | + * +---------+---------+-----------------+------------------+------------------+ * ^ * | - * sum(y) aggOffsetInRow = 4 + * sum(y) * }}} * */ @@ -70,15 +73,15 @@ object AggregateUtil { namedAggregates: Seq[CalcitePair[AggregateCall, String]], groupings: Array[Int], inputType: RelDataType) - : MapFunction[Row, Row] = { + : MapFunction[Row, Row] = { - val (aggFieldIndexes,aggregates) = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, groupings.length) val mapReturnType: RowTypeInfo = - createAggregateBufferDataType(groupings, aggregates, inputType) + createDataSetAggregateBufferDataType(groupings, aggregates, inputType) val mapFunction = new AggregateMapFunction[Row, Row]( aggregates, @@ -89,7 +92,6 @@ object 
AggregateUtil { mapFunction } - /** * Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates. * The output of the function contains the grouping keys and the timestamp and the intermediate @@ -98,17 +100,16 @@ object AggregateUtil { * event-time, the timestamp is not aligned and used to sort. * * The output is stored in Row by the following format: - * * {{{ - * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5 - * | | - * v v - * +---------+---------+--------+--------+--------+--------+--------+ - * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 | rowtime| - * +---------+---------+--------+--------+--------+--------+--------+ - * ^ ^ - * | | - * sum(y) aggOffsetInRow = 4 rowtime to group or sort + * avg(x) count(z) + * | | + * v v + * +---------+---------+----------------+----------------+------------------+-------+ + * |groupKey1|groupKey2| AvgAccumulator | SumAccumulator | CountAccumulator |rowtime| + * +---------+---------+----------------+----------------+------------------+-------+ + * ^ ^ + * | | + * sum(y) rowtime to group or sort * }}} * * NOTE: this function is only used for time based window on batch tables. @@ -119,7 +120,7 @@ object AggregateUtil { groupings: Array[Int], inputType: RelDataType, isParserCaseSensitive: Boolean) - : MapFunction[Row, Row] = { + : MapFunction[Row, Row] = { val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), @@ -127,7 +128,11 @@ object AggregateUtil { groupings.length) val mapReturnType: RowTypeInfo = - createAggregateBufferDataType(groupings, aggregates, inputType, Some(Array(Types.LONG))) + createDataSetAggregateBufferDataType( + groupings, + aggregates, + inputType, + Some(Array(Types.LONG))) val (timeFieldPos, tumbleTimeWindowSize) = window match { case EventTimeTumblingGroupWindow(_, time, size) => @@ -175,9 +180,6 @@ object AggregateUtil { inputType, groupings.length)._2 - val intermediateRowArity = groupings.length + - aggregates.map(_.intermediateDataType.length).sum - // the mapping relation between field index of intermediate aggregate Row and output Row. 
val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings) @@ -196,30 +198,26 @@ object AggregateUtil { case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => // tumbling time window val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) - if (aggregates.forall(_.supportPartial)) { + if (doAllSupportPartialMerge(aggregates)) { // for incremental aggregations new DataSetTumbleTimeWindowAggReduceCombineFunction( - intermediateRowArity, asLong(size), startPos, endPos, aggregates, groupingOffsetMapping, aggOffsetMapping, - intermediateRowArity + 1, // the additional field is used to store the time attribute outputType.getFieldCount) } else { // for non-incremental aggregations new DataSetTumbleTimeWindowAggReduceGroupFunction( - intermediateRowArity, asLong(size), startPos, endPos, aggregates, groupingOffsetMapping, aggOffsetMapping, - intermediateRowArity + 1, // the additional field is used to store the time attribute outputType.getFieldCount) } case EventTimeTumblingGroupWindow(_, _, size) => @@ -229,7 +227,6 @@ object AggregateUtil { aggregates, groupingOffsetMapping, aggOffsetMapping, - intermediateRowArity + 1,// the additional field is used to store the time attribute outputType.getFieldCount) case EventTimeSessionGroupWindow(_, _, gap) => @@ -238,8 +235,6 @@ object AggregateUtil { aggregates, groupingOffsetMapping, aggOffsetMapping, - // the additional two fields are used to store window-start and window-end attributes - intermediateRowArity + 2, outputType.getFieldCount, startPos, endPos, @@ -255,19 +250,16 @@ object AggregateUtil { * for aggregates. * The function returns intermediate aggregate values of all aggregate function which are * organized by the following format: - * * {{{ - * avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5 - * | | windowEnd(max(rowtime) - * | | | - * v v v - * +---------+---------+--------+--------+--------+--------+-----------+---------+ - * |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 |windowStart|windowEnd| - * +---------+---------+--------+--------+--------+--------+-----------+---------+ - * ^ ^ - * | | - * sum(y) aggOffsetInRow = 4 windowStart(min(rowtime)) - * + * avg(x) windowEnd(max(rowtime) + * | | + * v v + * +---------+---------+----------------+----------------+-------------+-----------+ + * |groupKey1|groupKey2| AvgAccumulator | SumAccumulator | windowStart | windowEnd | + * +---------+---------+----------------+----------------+-------------+-----------+ + * ^ ^ + * | | + * sum(y) windowStart(min(rowtime)) * }}} * */ @@ -276,20 +268,17 @@ object AggregateUtil { namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, groupings: Array[Int]) - : RichGroupCombineFunction[Row,Row] = { + : RichGroupCombineFunction[Row, Row] = { val aggregates = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, groupings.length)._2 - val intermediateRowArity = groupings.length + - aggregates.map(_.intermediateDataType.length).sum - window match { case EventTimeSessionGroupWindow(_, _, gap) => val combineReturnType: RowTypeInfo = - createAggregateBufferDataType( + createDataSetAggregateBufferDataType( groupings, aggregates, inputType, @@ -298,8 +287,6 @@ object AggregateUtil { new DataSetSessionWindowAggregateCombineGroupFunction( aggregates, groupings, - // the addition two fields are used to store window-start and window-end attributes - intermediateRowArity + 2, asLong(gap), combineReturnType) case _ => @@ -324,10 +311,10 @@ object 
AggregateUtil { inGroupingSet: Boolean) : RichGroupReduceFunction[Row, Row] = { - val aggregates = transformToAggregateFunctions( + val (aggFieldIndex, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - groupings.length)._2 + groupings.length) val (groupingOffsetMapping, aggOffsetMapping) = getGroupingOffsetAndAggOffsetMapping( @@ -342,19 +329,13 @@ object AggregateUtil { Array() } - val allPartialAggregate: Boolean = aggregates.forall(_.supportPartial) - - val intermediateRowArity = groupings.length + - aggregates.map(_.intermediateDataType.length).sum - val groupReduceFunction = - if (allPartialAggregate) { + if (doAllSupportPartialMerge(aggregates)) { new AggregateReduceCombineFunction( aggregates, groupingOffsetMapping, aggOffsetMapping, groupingSetsMapping, - intermediateRowArity, outputType.getFieldCount) } else { @@ -363,199 +344,109 @@ object AggregateUtil { groupingOffsetMapping, aggOffsetMapping, groupingSetsMapping, - intermediateRowArity, outputType.getFieldCount) } groupReduceFunction } /** - * Create a [[org.apache.flink.api.common.functions.ReduceFunction]] for incremental window - * aggregation. - * - */ - private[flink] def createIncrementalAggregateReduceFunction( - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - inputType: RelDataType, - outputType: RelDataType, - groupings: Array[Int]) - : IncrementalAggregateReduceFunction = { - - val aggregates = transformToAggregateFunctions( - namedAggregates.map(_.getKey),inputType,groupings.length)._2 - - val groupingOffsetMapping = - getGroupingOffsetAndAggOffsetMapping( - namedAggregates, - inputType, - outputType, - groupings)._1 - - val intermediateRowArity = groupings.length + aggregates.map(_.intermediateDataType.length).sum - val reduceFunction = new IncrementalAggregateReduceFunction( - aggregates, - groupingOffsetMapping, - intermediateRowArity) - reduceFunction - } - - /** - * Create an [[AllWindowFunction]] to compute non-partitioned group window aggregates. + * Create an [[AllWindowFunction]] for non-partitioned window aggregates. */ - private[flink] def createAllWindowAggregationFunction( + private[flink] def createAggregationAllWindowFunction( window: LogicalWindow, - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - inputType: RelDataType, - outputType: RelDataType, - groupings: Array[Int], + finalRowArity: Int, properties: Seq[NamedWindowProperty]) : AllWindowFunction[Row, Row, DataStreamWindow] = { - val aggFunction = - createAggregateGroupReduceFunction( - namedAggregates, - inputType, - outputType, - groupings, - inGroupingSet = false) - if (isTimeWindow(window)) { val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) - new AggregateAllTimeWindowFunction(aggFunction, startPos, endPos) - .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]] + new IncrementalAggregateAllTimeWindowFunction( + startPos, + endPos, + finalRowArity) + .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]] } else { - new AggregateAllWindowFunction(aggFunction) + new IncrementalAggregateAllWindowFunction( + finalRowArity) } } /** - * Create a [[WindowFunction]] to compute partitioned group window aggregates. - * + * Create a [[WindowFunction]] for group window aggregates. 
*/ - private[flink] def createWindowAggregationFunction( + private[flink] def createAggregationGroupWindowFunction( window: LogicalWindow, - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - inputType: RelDataType, - outputType: RelDataType, - groupings: Array[Int], + numGroupingKeys: Int, + numAggregates: Int, + finalRowArity: Int, properties: Seq[NamedWindowProperty]) : WindowFunction[Row, Row, Tuple, DataStreamWindow] = { - val aggFunction = - createAggregateGroupReduceFunction( - namedAggregates, - inputType, - outputType, - groupings, - inGroupingSet = false) - if (isTimeWindow(window)) { val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) - new AggregateTimeWindowFunction(aggFunction, startPos, endPos) - .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]] + new IncrementalAggregateTimeWindowFunction( + numGroupingKeys, + numAggregates, + startPos, + endPos, + finalRowArity) + .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]] } else { - new AggregateWindowFunction(aggFunction) + new IncrementalAggregateWindowFunction( + numGroupingKeys, + numAggregates, + finalRowArity) } } - /** - * Create an [[AllWindowFunction]] to finalize incrementally pre-computed non-partitioned - * window aggregates. - */ - private[flink] def createAllWindowIncrementalAggregationFunction( - window: LogicalWindow, + private[flink] def createDataStreamAggregateFunction( namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, outputType: RelDataType, - groupings: Array[Int], - properties: Seq[NamedWindowProperty]) - : AllWindowFunction[Row, Row, DataStreamWindow] = { + groupKeysIndex: Array[Int]) + : (ApiAggregateFunction[Row, Row, Row], RowTypeInfo, RowTypeInfo) = { - val aggregates = transformToAggregateFunctions( - namedAggregates.map(_.getKey),inputType,groupings.length)._2 + val (aggFields, aggregates) = + transformToAggregateFunctions(namedAggregates.map(_.getKey), inputType, groupKeysIndex.length) - val (groupingOffsetMapping, aggOffsetMapping) = - getGroupingOffsetAndAggOffsetMapping( - namedAggregates, - inputType, - outputType, - groupings) - - val finalRowArity = outputType.getFieldCount + val aggregateMapping = getAggregateMapping(namedAggregates, outputType) - if (isTimeWindow(window)) { - val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) - new IncrementalAggregateAllTimeWindowFunction( - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - finalRowArity, - startPos, - endPos) - .asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]] - } else { - new IncrementalAggregateAllWindowFunction( - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - finalRowArity) + if (aggregateMapping.length != namedAggregates.length) { + throw new TableException( + "Could not find output field in input data type or aggregate functions.") } - } - /** - * Create a [[WindowFunction]] to finalize incrementally pre-computed window aggregates. 
- */ - private[flink] def createWindowIncrementalAggregationFunction( - window: LogicalWindow, - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - inputType: RelDataType, - outputType: RelDataType, - groupings: Array[Int], - properties: Seq[NamedWindowProperty]) - : WindowFunction[Row, Row, Tuple, DataStreamWindow] = { + val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType)) - val aggregates = transformToAggregateFunctions( - namedAggregates.map(_.getKey),inputType,groupings.length)._2 + val accumulatorRowType = createAccumulatorRowType(inputType, aggregates) + val aggResultRowType = new RowTypeInfo(aggResultTypes: _*) + val aggFunction = new AggregateAggFunction(aggregates, aggFields) - val (groupingOffsetMapping, aggOffsetMapping) = - getGroupingOffsetAndAggOffsetMapping( - namedAggregates, - inputType, - outputType, - groupings) - - val finalRowArity = outputType.getFieldCount - - if (isTimeWindow(window)) { - val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) - new IncrementalAggregateTimeWindowFunction( - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - finalRowArity, - startPos, - endPos) - .asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]] - } else { - new IncrementalAggregateWindowFunction( - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - finalRowArity) - } + (aggFunction, accumulatorRowType, aggResultRowType) } /** - * Return true if all aggregates can be partially computed. False otherwise. + * Return true if all aggregates can be partially merged. False otherwise. */ - private[flink] def doAllSupportPartialAggregation( + private[flink] def doAllSupportPartialMerge( aggregateCalls: Seq[AggregateCall], inputType: RelDataType, groupKeysCount: Int): Boolean = { - transformToAggregateFunctions( + + val aggregateList = transformToAggregateFunctions( aggregateCalls, inputType, - groupKeysCount)._2.forall(_.supportPartial) + groupKeysCount)._2 + + doAllSupportPartialMerge(aggregateList) + } + + /** + * Return true if all aggregates can be partially merged. False otherwise. 
+ */ + private[flink] def doAllSupportPartialMerge( + aggregateList: Array[TableAggregateFunction[_ <: Any]]): Boolean = { + aggregateList.forall(ifMethodExistInFunction("merge", _)) } /** @@ -601,10 +492,10 @@ object AggregateUtil { // map from field -> i$field or field -> i$field_0 val groupingFields = inputFields.map(inputFieldName => { - val base = "i$" + inputFieldName - var name = base - var i = 0 - while (inputFields.contains(name)) { + val base = "i$" + inputFieldName + var name = base + var i = 0 + while (inputFields.contains(name)) { name = base + "_" + i // if i$XXX is already a field it will be suffixed by _NUMBER i = i + 1 } @@ -642,7 +533,7 @@ object AggregateUtil { } private[flink] def computeWindowStartEndPropertyPos( - properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = { + properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = { val propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) { (p, x) => p match { @@ -663,15 +554,15 @@ object AggregateUtil { } private def transformToAggregateFunctions( - aggregateCalls: Seq[AggregateCall], - inputType: RelDataType, - groupKeysCount: Int): (Array[Int], Array[Aggregate[_ <: Any]]) = { + aggregateCalls: Seq[AggregateCall], + inputType: RelDataType, + groupKeysCount: Int): (Array[Int], Array[TableAggregateFunction[_ <: Any]]) = { // store the aggregate fields of each aggregate function, by the same order of aggregates. val aggFieldIndexes = new Array[Int](aggregateCalls.size) - val aggregates = new Array[Aggregate[_ <: Any]](aggregateCalls.size) + val aggregates = new Array[TableAggregateFunction[_ <: Any]](aggregateCalls.size) - // set the start offset of aggregate buffer value to group keys' length, + // set the start offset of aggregate buffer value to group keys' length, // as all the group keys would be moved to the start fields of intermediate // aggregate data. 
var aggOffset = groupKeysCount @@ -696,19 +587,19 @@ object AggregateUtil { case _: SqlSumAggFunction | _: SqlSumEmptyIsZeroAggFunction => { aggregates(index) = sqlTypeName match { case TINYINT => - new ByteSumAggregate + new ByteSumAggFunction case SMALLINT => - new ShortSumAggregate + new ShortSumAggFunction case INTEGER => - new IntSumAggregate + new IntSumAggFunction case BIGINT => - new LongSumAggregate + new LongSumAggFunction case FLOAT => - new FloatSumAggregate + new FloatSumAggFunction case DOUBLE => - new DoubleSumAggregate + new DoubleSumAggFunction case DECIMAL => - new DecimalSumAggregate + new DecimalSumAggFunction case sqlType: SqlTypeName => throw new TableException("Sum aggregate does no support type:" + sqlType) } @@ -716,19 +607,19 @@ object AggregateUtil { case _: SqlAvgAggFunction => { aggregates(index) = sqlTypeName match { case TINYINT => - new ByteAvgAggregate + new ByteAvgAggFunction case SMALLINT => - new ShortAvgAggregate + new ShortAvgAggFunction case INTEGER => - new IntAvgAggregate + new IntAvgAggFunction case BIGINT => - new LongAvgAggregate + new LongAvgAggFunction case FLOAT => - new FloatAvgAggregate + new FloatAvgAggFunction case DOUBLE => - new DoubleAvgAggregate + new DoubleAvgAggFunction case DECIMAL => - new DecimalAvgAggregate + new DecimalAvgAggFunction case sqlType: SqlTypeName => throw new TableException("Avg aggregate does no support type:" + sqlType) } @@ -737,84 +628,114 @@ object AggregateUtil { aggregates(index) = if (sqlMinMaxFunction.getKind == SqlKind.MIN) { sqlTypeName match { case TINYINT => - new ByteMinAggregate + new ByteMinAggFunction case SMALLINT => - new ShortMinAggregate + new ShortMinAggFunction case INTEGER => - new IntMinAggregate + new IntMinAggFunction case BIGINT => - new LongMinAggregate + new LongMinAggFunction case FLOAT => - new FloatMinAggregate + new FloatMinAggFunction case DOUBLE => - new DoubleMinAggregate + new DoubleMinAggFunction case DECIMAL => - new DecimalMinAggregate + new DecimalMinAggFunction case BOOLEAN => - new BooleanMinAggregate + new BooleanMinAggFunction case sqlType: SqlTypeName => throw new TableException("Min aggregate does no support type:" + sqlType) } } else { sqlTypeName match { case TINYINT => - new ByteMaxAggregate + new ByteMaxAggFunction case SMALLINT => - new ShortMaxAggregate + new ShortMaxAggFunction case INTEGER => - new IntMaxAggregate + new IntMaxAggFunction case BIGINT => - new LongMaxAggregate + new LongMaxAggFunction case FLOAT => - new FloatMaxAggregate + new FloatMaxAggFunction case DOUBLE => - new DoubleMaxAggregate + new DoubleMaxAggFunction case DECIMAL => - new DecimalMaxAggregate + new DecimalMaxAggFunction case BOOLEAN => - new BooleanMaxAggregate + new BooleanMaxAggFunction case sqlType: SqlTypeName => throw new TableException("Max aggregate does no support type:" + sqlType) } } } case _: SqlCountAggFunction => - aggregates(index) = new CountAggregate + aggregates(index) = new CountAggFunction case unSupported: SqlAggFunction => throw new TableException("unsupported Function: " + unSupported.getName) } - setAggregateDataOffset(index) - } - - // set the aggregate intermediate data start index in Row, and update current value. 
- def setAggregateDataOffset(index: Int): Unit = { - aggregates(index).setAggOffsetInRow(aggOffset) - aggOffset += aggregates(index).intermediateDataType.length } (aggFieldIndexes, aggregates) } - private def createAggregateBufferDataType( - groupings: Array[Int], - aggregates: Array[Aggregate[_]], - inputType: RelDataType, - windowKeyTypes: Option[Array[TypeInformation[_]]] = None): RowTypeInfo = { + private def createAccumulatorType( + inputType: RelDataType, + aggregates: Array[TableAggregateFunction[_]]): Seq[TypeInformation[_]] = { + + val aggTypes: Seq[TypeInformation[_]] = + aggregates.map { + agg => + val accType = agg.getAccumulatorType() + if (accType != null) { + accType + } else { + val accumulator = agg.createAccumulator() + try { + TypeInformation.of(accumulator.getClass) + } catch { + case ite: InvalidTypesException => + throw new TableException( + "Cannot infer type of accumulator. " + + "You can override AggregateFunction.getAccumulatorType() to specify the type.", + ite) + } + } + } + + aggTypes + } + + private def createDataSetAggregateBufferDataType( + groupings: Array[Int], + aggregates: Array[TableAggregateFunction[_]], + inputType: RelDataType, + windowKeyTypes: Option[Array[TypeInformation[_]]] = None): RowTypeInfo = { // get the field data types of group keys. - val groupingTypes: Seq[TypeInformation[_]] = groupings - .map(inputType.getFieldList.get(_).getType) - .map(FlinkTypeFactory.toTypeInfo) + val groupingTypes: Seq[TypeInformation[_]] = + groupings + .map(inputType.getFieldList.get(_).getType) + .map(FlinkTypeFactory.toTypeInfo) // get all field data types of all intermediate aggregates - val aggTypes: Seq[TypeInformation[_]] = aggregates.flatMap(_.intermediateDataType) + val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(inputType, aggregates) // concat group key types, aggregation types, and window key types - val allFieldTypes:Seq[TypeInformation[_]] = windowKeyTypes match { + val allFieldTypes: Seq[TypeInformation[_]] = windowKeyTypes match { case None => groupingTypes ++: aggTypes case _ => groupingTypes ++: aggTypes ++: windowKeyTypes.get } - new RowTypeInfo(allFieldTypes :_*) + new RowTypeInfo(allFieldTypes: _*) + } + + private def createAccumulatorRowType( + inputType: RelDataType, + aggregates: Array[TableAggregateFunction[_]]): RowTypeInfo = { + + val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(inputType, aggregates) + + new RowTypeInfo(aggTypes: _*) } // Find the mapping between the index of aggregate list and aggregated value index in output Row. @@ -826,12 +747,12 @@ object AggregateUtil { // field index in output Row. var aggOffsetMapping = ArrayBuffer[(Int, Int)]() - outputType.getFieldList.zipWithIndex.foreach{ + outputType.getFieldList.zipWithIndex.foreach { case (outputFieldType, outputIndex) => namedAggregates.zipWithIndex.foreach { case (namedAggCall, aggregateIndex) => if (namedAggCall.getValue.equals(outputFieldType.getName) && - namedAggCall.getKey.getType.equals(outputFieldType.getType)) { + namedAggCall.getKey.getType.equals(outputFieldType.getType)) { aggOffsetMapping += ((outputIndex, aggregateIndex)) } } @@ -856,7 +777,7 @@ object AggregateUtil { // find the field index in input data type. 
case (inputFieldType, inputIndex) => if (outputFieldType.getName.equals(inputFieldType.getName) && - outputFieldType.getType.equals(inputFieldType.getType)) { + outputFieldType.getType.equals(inputFieldType.getType)) { // as aggregated field in output data type would not have a matched field in // input data, so if inputIndex is not -1, it must be a group key. Then we can // find the field index in buffer data by the group keys index mapping between @@ -906,6 +827,5 @@ object AggregateUtil { case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value case _ => throw new IllegalArgumentException() } - }
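Before the per-file changes below, it may help to see the contract this patch migrates to. The old Aggregate interface wrote partial values directly into intermediate Row fields (prepare/merge/evaluate); the new code creates, fills, merges, and evaluates accumulator objects instead. The sketch below illustrates that lifecycle with a simplified sum function. The trait and class names here are illustrative stand-ins only; the real contract lives in org.apache.flink.table.functions.AggregateFunction and the concrete *AggFunction classes instantiated in transformToAggregateFunctions above.

import java.util.{List => JList}

// Simplified stand-ins for org.apache.flink.table.functions.{Accumulator, AggregateFunction}.
trait Accumulator

trait SketchAggregateFunction[T] {
  def createAccumulator(): Accumulator
  def accumulate(accumulator: Accumulator, value: Any): Unit
  def merge(accumulators: JList[Accumulator]): Accumulator
  def getValue(accumulator: Accumulator): T
}

// A sum over Long values, analogous in spirit to LongSumAggFunction.
class SumAccumulator extends Accumulator {
  var sum: Long = 0L
}

class SketchLongSumAggFunction extends SketchAggregateFunction[Long] {

  override def createAccumulator(): Accumulator = new SumAccumulator

  override def accumulate(accumulator: Accumulator, value: Any): Unit = {
    val acc = accumulator.asInstanceOf[SumAccumulator]
    acc.sum += value.asInstanceOf[Long]
  }

  // Merging a list of accumulators into a fresh one is the operation that
  // doAllSupportPartialMerge probes for ("merge") and that the DataSet
  // functions below invoke on their buffered accumulator lists.
  override def merge(accumulators: JList[Accumulator]): Accumulator = {
    val result = new SumAccumulator
    val it = accumulators.iterator()
    while (it.hasNext) {
      result.sum += it.next().asInstanceOf[SumAccumulator].sum
    }
    result
  }

  override def getValue(accumulator: Accumulator): Long =
    accumulator.asInstanceOf[SumAccumulator].sum
}

With this contract, an intermediate row has the shape |groupKeys...|accumulators...|(optional time fields)|, which is why the explicit intermediateRowArity bookkeeping could be dropped from the constructors in this patch: the accumulator block is always exactly aggregates.length fields wide.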
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala deleted file mode 100644 index 5491b1d..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateWindowFunction.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.runtime.aggregate - -import java.lang.Iterable - -import org.apache.flink.api.common.functions.RichGroupReduceFunction -import org.apache.flink.api.java.tuple.Tuple -import org.apache.flink.types.Row -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction -import org.apache.flink.streaming.api.windowing.windows.Window -import org.apache.flink.util.Collector - -class AggregateWindowFunction[W <: Window](groupReduceFunction: RichGroupReduceFunction[Row, Row]) - extends RichWindowFunction[Row, Row, Tuple, W] { - - override def open(parameters: Configuration): Unit = { - groupReduceFunction.open(parameters) - } - - override def apply( - key: Tuple, - window: W, - input: Iterable[Row], - out: Collector[Row]): Unit = { - - groupReduceFunction.reduce(input, out) - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala index f1d91a3..47fa0f1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateCombineGroupFunction.scala @@ -18,40 +18,44 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable +import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupCombineFunction import org.apache.flink.api.common.typeinfo.TypeInformation import 
org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.types.Row import org.apache.flink.configuration.Configuration +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} import org.apache.flink.util.{Collector, Preconditions} /** * This wraps the aggregate logic inside of * [[org.apache.flink.api.java.operators.GroupCombineOperator]]. * - * @param aggregates The aggregate functions. - * @param groupingKeys The indexes of the grouping fields. - * @param intermediateRowArity The intermediate row field count. - * @param gap Session time window gap. + * @param aggregates The aggregate functions. + * @param groupingKeys The indexes of the grouping fields. + * @param gap Session time window gap. * @param intermediateRowType Intermediate row data type. */ class DataSetSessionWindowAggregateCombineGroupFunction( - aggregates: Array[Aggregate[_ <: Any]], + aggregates: Array[AggregateFunction[_ <: Any]], groupingKeys: Array[Int], - intermediateRowArity: Int, gap: Long, @transient intermediateRowType: TypeInformation[Row]) - extends RichGroupCombineFunction[Row,Row] with ResultTypeQueryable[Row] { + extends RichGroupCombineFunction[Row, Row] with ResultTypeQueryable[Row] { private var aggregateBuffer: Row = _ - private var rowTimeFieldPos = 0 + private var accumStartPos: Int = groupingKeys.length + private var rowTimeFieldPos = accumStartPos + aggregates.length + private val maxMergeLen = 16 + val accumulatorList = Array.fill(aggregates.length) { + new JArrayList[Accumulator]() + } override def open(config: Configuration) { Preconditions.checkNotNull(aggregates) Preconditions.checkNotNull(groupingKeys) - aggregateBuffer = new Row(intermediateRowArity) - rowTimeFieldPos = intermediateRowArity - 2 + aggregateBuffer = new Row(rowTimeFieldPos + 2) } /** @@ -59,7 +63,7 @@ class DataSetSessionWindowAggregateCombineGroupFunction( * (current'rowtime - previous'rowtime > gap), and then merge data (within a unified window) * into an aggregate buffer. * - * @param records Sub-grouped intermediate aggregate Rows. + * @param records Sub-grouped intermediate aggregate Rows. * @return Combined intermediate aggregate Row. * */ @@ -68,10 +72,15 @@ class DataSetSessionWindowAggregateCombineGroupFunction( var windowStart: java.lang.Long = null var windowEnd: java.lang.Long = null var currentRowTime: java.lang.Long = null + accumulatorList.foreach(_.clear()) val iterator = records.iterator() + + + var count: Int = 0 while (iterator.hasNext) { val record = iterator.next() + count += 1 currentRowTime = record.getField(rowTimeFieldPos).asInstanceOf[Long] // initial traversal or opening a new window if (windowEnd == null || (windowEnd != null && (currentRowTime > windowEnd))) { @@ -79,7 +88,11 @@ class DataSetSessionWindowAggregateCombineGroupFunction( // calculate the current window and open a new window. if (windowEnd != null) { // emit the current window's merged data - doCollect(out, windowStart, windowEnd) + doCollect(out, accumulatorList, windowStart, windowEnd) + + // clear the accumulator list for all aggregate + accumulatorList.foreach(_.clear()) + count = 0 } else { // set group keys to aggregateBuffer. for (i <- groupingKeys.indices) { @@ -87,36 +100,59 @@ class DataSetSessionWindowAggregateCombineGroupFunction( } } - // initiate intermediate aggregate value. - aggregates.foreach(_.initiate(aggregateBuffer)) windowStart = record.getField(rowTimeFieldPos).asInstanceOf[Long] } - // merge intermediate aggregate value to the buffered value.
- aggregates.foreach(_.merge(record, aggregateBuffer)) + // collect the accumulators for each aggregate + for (i <- aggregates.indices) { + accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator]) + } + + // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one + // accumulator + if (count > maxMergeLen) { + count = 0 + for (i <- aggregates.indices) { + val agg = aggregates(i) + val accumulator = agg.merge(accumulatorList(i)) + accumulatorList(i).clear() + accumulatorList(i).add(accumulator) + } + } // the current rowtime is the last rowtime of the next calculation. windowEnd = currentRowTime + gap } // emit the merged data of the current window. - doCollect(out, windowStart, windowEnd) + doCollect(out, accumulatorList, windowStart, windowEnd) } /** * Emit the merged data of the current window. - * @param windowStart the window's start attribute value is the min (rowtime) - * of all rows in the window. - * @param windowEnd the window's end property value is max (rowtime) + gap - * for all rows in the window. + * + * @param out the collection of the aggregate results + * @param accumulatorList an array (indexed by aggregate index) of the accumulator lists for + * each aggregate + * @param windowStart the window's start attribute value is the min (rowtime) + * of all rows in the window. + * @param windowEnd the window's end property value is max (rowtime) + gap + * for all rows in the window. */ def doCollect( - out: Collector[Row], - windowStart: Long, - windowEnd: Long): Unit = { + out: Collector[Row], + accumulatorList: Array[JArrayList[Accumulator]], + windowStart: Long, + windowEnd: Long): Unit = { + + // merge the accumulators into one accumulator + for (i <- aggregates.indices) { + aggregateBuffer.setField(accumStartPos + i, aggregates(i).merge(accumulatorList(i))) + } - // intermediate Row WindowStartPos is rowtime pos . + // intermediate Row WindowStartPos is rowtime pos. aggregateBuffer.setField(rowTimeFieldPos, windowStart) - // intermediate Row WindowEndPos is rowtime pos + 1 . + + // intermediate Row WindowEndPos is rowtime pos + 1. 
aggregateBuffer.setField(rowTimeFieldPos + 1, windowEnd) out.collect(aggregateBuffer) http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala index 99d241d..1570671 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala @@ -18,10 +18,12 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable +import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.types.Row import org.apache.flink.configuration.Configuration +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} import org.apache.flink.util.{Collector, Preconditions} /** @@ -30,49 +32,51 @@ import org.apache.flink.util.{Collector, Preconditions} * on batch. * * Note: - * - * This can handle two input types (depending if input is combined or not): + * + * This can handle two input types (depending if input is combined or not): * * 1. when partial aggregate is not supported, the input data structure of reduce is - * |groupKey1|groupKey2|sum1|count1|sum2|count2|rowTime| + * |groupKey1|groupKey2|sum1|count1|sum2|count2|rowTime| * 2. when partial aggregate is supported, the input data structure of reduce is - * |groupKey1|groupKey2|sum1|count1|sum2|count2|windowStart|windowEnd| + * |groupKey1|groupKey2|sum1|count1|sum2|count2|windowStart|windowEnd| * - * @param aggregates The aggregate functions. - * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping The index mapping between aggregate function list and aggregated value - * index in output Row. - * @param intermediateRowArity The intermediate row field count. - * @param finalRowArity The output row field count. + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + * @param aggregateMapping The index mapping between aggregate function list and + * aggregated value index in output Row. + * @param finalRowArity The output row field count. * @param finalRowWindowStartPos The relative window-start field position. - * @param finalRowWindowEndPos The relative window-end field position. - * @param gap Session time window gap. + * @param finalRowWindowEndPos The relative window-end field position. + * @param gap Session time window gap. 
*/ class DataSetSessionWindowAggregateReduceGroupFunction( - aggregates: Array[Aggregate[_ <: Any]], + aggregates: Array[AggregateFunction[_ <: Any]], groupKeysMapping: Array[(Int, Int)], aggregateMapping: Array[(Int, Int)], - intermediateRowArity: Int, finalRowArity: Int, finalRowWindowStartPos: Option[Int], finalRowWindowEndPos: Option[Int], - gap:Long, + gap: Long, isInputCombined: Boolean) extends RichGroupReduceFunction[Row, Row] { private var aggregateBuffer: Row = _ - private var intermediateRowWindowStartPos = 0 - private var intermediateRowWindowEndPos = 0 private var output: Row = _ private var collector: TimeWindowPropertyCollector = _ + private var accumStartPos: Int = groupKeysMapping.length + private var intermediateRowArity: Int = accumStartPos + aggregates.length + 2 + private var intermediateRowWindowStartPos = intermediateRowArity - 2 + private var intermediateRowWindowEndPos = intermediateRowArity - 1 + private val maxMergeLen = 16 + val accumulatorList = Array.fill(aggregates.length) { + new JArrayList[Accumulator]() + } override def open(config: Configuration) { Preconditions.checkNotNull(aggregates) Preconditions.checkNotNull(groupKeysMapping) aggregateBuffer = new Row(intermediateRowArity) - intermediateRowWindowStartPos = intermediateRowArity - 2 - intermediateRowWindowEndPos = intermediateRowArity - 1 output = new Row(finalRowArity) collector = new TimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos) } @@ -91,11 +95,15 @@ class DataSetSessionWindowAggregateReduceGroupFunction( var windowStart: java.lang.Long = null var windowEnd: java.lang.Long = null - var currentRowTime:java.lang.Long = null + var currentRowTime: java.lang.Long = null + accumulatorList.foreach(_.clear()) val iterator = records.iterator() + + var count: Int = 0 while (iterator.hasNext) { val record = iterator.next() + count += 1 currentRowTime = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long] // initial traversal or opening a new window if (null == windowEnd || @@ -104,7 +112,11 @@ class DataSetSessionWindowAggregateReduceGroupFunction( // calculate the current window and open a new window if (null != windowEnd) { // evaluate and emit the current window's result. - doEvaluateAndCollect(out, windowStart, windowEnd) + doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd) + + // clear the accumulator list for all aggregate + accumulatorList.foreach(_.clear()) + count = 0 } else { // set group keys value to final output. groupKeysMapping.foreach { @@ -112,13 +124,26 @@ class DataSetSessionWindowAggregateReduceGroupFunction( output.setField(after, record.getField(previous)) } } - // initiate intermediate aggregate value. - aggregates.foreach(_.initiate(aggregateBuffer)) + windowStart = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long] } - // merge intermediate aggregate value to the buffered value. 
- aggregates.foreach(_.merge(record, aggregateBuffer)) + // collect the accumulators for each aggregate + for (i <- aggregates.indices) { + accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator]) + } + + // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one + // accumulator + if (count > maxMergeLen) { + count = 0 + for (i <- aggregates.indices) { + val agg = aggregates(i) + val accumulator = agg.merge(accumulatorList(i)) + accumulatorList(i).clear() + accumulatorList(i).add(accumulator) + } + } windowEnd = if (isInputCombined) { // partial aggregate is supported @@ -129,25 +154,32 @@ class DataSetSessionWindowAggregateReduceGroupFunction( } } // evaluate and emit the current window's result. - doEvaluateAndCollect(out, windowStart, windowEnd) + doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd) } /** * Evaluate and emit the data of the current window. - * @param windowStart the window's start attribute value is the min (rowtime) - * of all rows in the window. - * @param windowEnd the window's end property value is max (rowtime) + gap - * for all rows in the window. + * + * @param out the collection of the aggregate results + * @param accumulatorList an array (indexed by aggregate index) of the accumulator lists for + * each aggregate + * @param windowStart the window's start attribute value is the min (rowtime) of all rows + * in the window. + * @param windowEnd the window's end property value is max (rowtime) + gap for all rows + * in the window. */ def doEvaluateAndCollect( - out: Collector[Row], - windowStart: Long, - windowEnd: Long): Unit = { + out: Collector[Row], + accumulatorList: Array[JArrayList[Accumulator]], + windowStart: Long, + windowEnd: Long): Unit = { - // evaluate final aggregate value and set to output. 
+ // merge the accumulators and then get value for the final output aggregateMapping.foreach { case (after, previous) => - output.setField(after, aggregates(previous).evaluate(aggregateBuffer)) + val agg = aggregates(previous) + val accum = agg.merge(accumulatorList(previous)) + output.setField(after, agg.getValue(accum)) } // adds TimeWindow properties to output then emit output http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala index 40dad17..b722330 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala @@ -18,9 +18,11 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable +import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} import org.apache.flink.types.Row import org.apache.flink.util.{Collector, Preconditions} @@ -29,26 +31,30 @@ import org.apache.flink.util.{Collector, Preconditions} * [[org.apache.flink.api.java.operators.GroupReduceOperator]]. * It is only used for tumbling count-window on batch. * - * @param windowSize Tumble count window size - * @param aggregates The aggregate functions. + * @param windowSize Tumble count window size + * @param aggregates The aggregate functions. * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row * and output Row. * @param aggregateMapping The index mapping between aggregate function list and aggregated value * index in output Row. 
- * @param intermediateRowArity The intermediate row field count - * @param finalRowArity The output row field count + * @param finalRowArity The output row field count */ class DataSetTumbleCountWindowAggReduceGroupFunction( private val windowSize: Long, - private val aggregates: Array[Aggregate[_ <: Any]], + private val aggregates: Array[AggregateFunction[_ <: Any]], private val groupKeysMapping: Array[(Int, Int)], private val aggregateMapping: Array[(Int, Int)], - private val intermediateRowArity: Int, private val finalRowArity: Int) extends RichGroupReduceFunction[Row, Row] { private var aggregateBuffer: Row = _ private var output: Row = _ + private val accumStartPos: Int = groupKeysMapping.length + private val intermediateRowArity: Int = accumStartPos + aggregates.length + 1 + private val maxMergeLen = 16 + val accumulatorList = Array.fill(aggregates.length) { + new JArrayList[Accumulator]() + } override def open(config: Configuration) { Preconditions.checkNotNull(aggregates) @@ -60,30 +66,49 @@ class DataSetTumbleCountWindowAggReduceGroupFunction( override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { var count: Long = 0 + accumulatorList.foreach(_.clear()) val iterator = records.iterator() while (iterator.hasNext) { val record = iterator.next() + if (count == 0) { - // initiate intermediate aggregate value. - aggregates.foreach(_.initiate(aggregateBuffer)) + // clear the accumulator list for all aggregate + accumulatorList.foreach(_.clear()) } - // merge intermediate aggregate value to buffer. - aggregates.foreach(_.merge(record, aggregateBuffer)) + // collect the accumulators for each aggregate + for (i <- aggregates.indices) { + accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator]) + } count += 1 + + // for every maxMergeLen accumulators, we merge them into one + if (count % maxMergeLen == 0) { + for (i <- aggregates.indices) { + val agg = aggregates(i) + val accumulator = agg.merge(accumulatorList(i)) + accumulatorList(i).clear() + accumulatorList(i).add(accumulator) + } + } + if (windowSize == count) { // set group keys value to final output. groupKeysMapping.foreach { case (after, previous) => output.setField(after, record.getField(previous)) } - // evaluate final aggregate value and set to output. 
+ + // merge the accumulators and then get value for the final output aggregateMapping.foreach { case (after, previous) => - output.setField(after, aggregates(previous).evaluate(aggregateBuffer)) + val agg = aggregates(previous) + val accumulator = agg.merge(accumulatorList(previous)) + output.setField(after, agg.getValue(accumulator)) } + // emit the output out.collect(output) count = 0 http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala index a72c9ca..d507a58 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala @@ -18,8 +18,10 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable +import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.CombineFunction +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} import org.apache.flink.types.Row /** @@ -28,68 +30,86 @@ import org.apache.flink.types.Row * [[org.apache.flink.api.java.operators.GroupCombineOperator]]. * It is used for tumbling time-window on batch. * - * @param rowtimePos The rowtime field index in input row - * @param windowSize Tumbling time window size - * @param windowStartPos The relative window-start field position to the last field of output row - * @param windowEndPos The relative window-end field position to the last field of output row - * @param aggregates The aggregate functions. + * @param windowSize Tumbling time window size + * @param windowStartPos The relative window-start field position to the last field of output row + * @param windowEndPos The relative window-end field position to the last field of output row + * @param aggregates The aggregate functions. * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row * and output Row. * @param aggregateMapping The index mapping between aggregate function list and aggregated value * index in output Row. 
- * @param intermediateRowArity The intermediate row field count - * @param finalRowArity The output row field count + * @param finalRowArity The output row field count */ class DataSetTumbleTimeWindowAggReduceCombineFunction( - rowtimePos: Int, windowSize: Long, windowStartPos: Option[Int], windowEndPos: Option[Int], - aggregates: Array[Aggregate[_ <: Any]], + aggregates: Array[AggregateFunction[_ <: Any]], groupKeysMapping: Array[(Int, Int)], aggregateMapping: Array[(Int, Int)], - intermediateRowArity: Int, finalRowArity: Int) extends DataSetTumbleTimeWindowAggReduceGroupFunction( - rowtimePos, windowSize, windowStartPos, windowEndPos, aggregates, groupKeysMapping, aggregateMapping, - intermediateRowArity, finalRowArity) - with CombineFunction[Row, Row] { + with CombineFunction[Row, Row] { /** * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer, * - * @param records Sub-grouped intermediate aggregate Rows iterator. + * @param records Sub-grouped intermediate aggregate Rows iterator. * @return Combined intermediate aggregate Row. * */ override def combine(records: Iterable[Row]): Row = { - // initiate intermediate aggregate value. - aggregates.foreach(_.initiate(aggregateBuffer)) - - // merge intermediate aggregate value to buffer. var last: Row = null + accumulatorList.foreach(_.clear()) val iterator = records.iterator() + + var count: Int = 0 while (iterator.hasNext) { val record = iterator.next() - aggregates.foreach(_.merge(record, aggregateBuffer)) + count += 1 + // per each aggregator, collect its accumulators to a list + for (i <- aggregates.indices) { + accumulatorList(i).add(record.getField(groupKeysMapping.length + i) + .asInstanceOf[Accumulator]) + } + // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one + // accumulator + if (count > maxMergeLen) { + count = 0 + for (i <- aggregates.indices) { + val agg = aggregates(i) + val accumulator = agg.merge(accumulatorList(i)) + accumulatorList(i).clear() + accumulatorList(i).add(accumulator) + } + } last = record } + // per each aggregator, merge list of accumulators into one and save the result to the + // intermediate aggregate buffer + for (i <- aggregates.indices) { + val agg = aggregates(i) + aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i))) + } + // set group keys to aggregateBuffer. 
for (i <- groupKeysMapping.indices) { aggregateBuffer.setField(i, last.getField(i)) } // set the rowtime attribute + val rowtimePos = groupKeysMapping.length + aggregates.length + aggregateBuffer.setField(rowtimePos, last.getField(rowtimePos)) aggregateBuffer http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala index a4c03b9..63d2aeb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala @@ -18,10 +18,11 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable +import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.windowing.windows.TimeWindow +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} import org.apache.flink.types.Row import org.apache.flink.util.{Collector, Preconditions} @@ -30,33 +31,36 @@ import org.apache.flink.util.{Collector, Preconditions} * [[org.apache.flink.api.java.operators.GroupReduceOperator]]. It is used for tumbling time-window * on batch. * - * @param rowtimePos The rowtime field index in input row - * @param windowSize Tumbling time window size - * @param windowStartPos The relative window-start field position to the last field of output row - * @param windowEndPos The relative window-end field position to the last field of output row - * @param aggregates The aggregate functions. + * @param windowSize Tumbling time window size + * @param windowStartPos The relative window-start field position to the last field of output row + * @param windowEndPos The relative window-end field position to the last field of output row + * @param aggregates The aggregate functions. * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row * and output Row. * @param aggregateMapping The index mapping between aggregate function list and aggregated value * index in output Row. 
- * @param intermediateRowArity The intermediate row field count - * @param finalRowArity The output row field count + * @param finalRowArity The output row field count */ class DataSetTumbleTimeWindowAggReduceGroupFunction( - rowtimePos: Int, windowSize: Long, windowStartPos: Option[Int], windowEndPos: Option[Int], - aggregates: Array[Aggregate[_ <: Any]], + aggregates: Array[AggregateFunction[_ <: Any]], groupKeysMapping: Array[(Int, Int)], aggregateMapping: Array[(Int, Int)], - intermediateRowArity: Int, finalRowArity: Int) extends RichGroupReduceFunction[Row, Row] { private var collector: TimeWindowPropertyCollector = _ protected var aggregateBuffer: Row = _ private var output: Row = _ + private val accumStartPos: Int = groupKeysMapping.length + private val rowtimePos: Int = accumStartPos + aggregates.length + private val intermediateRowArity: Int = rowtimePos + 1 + protected val maxMergeLen = 16 + val accumulatorList = Array.fill(aggregates.length) { + new JArrayList[Accumulator]() + } override def open(config: Configuration) { Preconditions.checkNotNull(aggregates) @@ -68,16 +72,30 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { - // initiate intermediate aggregate value. - aggregates.foreach(_.initiate(aggregateBuffer)) - - // merge intermediate aggregate value to buffer. var last: Row = null + accumulatorList.foreach(_.clear()) val iterator = records.iterator() + + var count: Int = 0 while (iterator.hasNext) { val record = iterator.next() - aggregates.foreach(_.merge(record, aggregateBuffer)) + count += 1 + // per each aggregator, collect its accumulators to a list + for (i <- aggregates.indices) { + accumulatorList(i).add(record.getField(accumStartPos + i).asInstanceOf[Accumulator]) + } + // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one + // accumulator + if (count > maxMergeLen) { + count = 0 + for (i <- aggregates.indices) { + val agg = aggregates(i) + val accumulator = agg.merge(accumulatorList(i)) + accumulatorList(i).clear() + accumulatorList(i).add(accumulator) + } + } last = record } @@ -87,10 +105,14 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( output.setField(after, last.getField(previous)) } - // evaluate final aggregate value and set to output. + // get final aggregate value and set to output. 
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
index 5c3d374..68088fc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow
+import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.types.Row
 import org.apache.flink.util.Preconditions

@@ -34,13 +35,13 @@ import org.apache.flink.util.Preconditions
  * append an (aligned) rowtime field to the end of the output row.
  */
 class DataSetWindowAggregateMapFunction(
-    private val aggregates: Array[Aggregate[_]],
+    private val aggregates: Array[AggregateFunction[_]],
     private val aggFields: Array[Int],
     private val groupingKeys: Array[Int],
-    private val timeFieldPos: Int,      // time field position in input row
+    private val timeFieldPos: Int, // time field position in input row
     private val tumbleTimeWindowSize: Option[Long],
     @transient private val returnType: TypeInformation[Row])
-    extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] {
+  extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] {

   private var output: Row = _
   // rowtime index in the buffer output row
@@ -51,18 +52,22 @@ class DataSetWindowAggregateMapFunction(
     Preconditions.checkNotNull(aggFields)
     Preconditions.checkArgument(aggregates.length == aggFields.length)
     // add one more arity to store rowtime
-    val partialRowLength = groupingKeys.length +
-      aggregates.map(_.intermediateDataType.length).sum + 1
+    val partialRowLength = groupingKeys.length + aggregates.length + 1
     // set rowtime to the last field of the output row
    rowtimeIndex = partialRowLength - 1
     output = new Row(partialRowLength)
   }

   override def map(input: Row): Row = {
+
     for (i <- aggregates.indices) {
+      val agg = aggregates(i)
       val fieldValue = input.getField(aggFields(i))
-      aggregates(i).prepare(fieldValue, output)
+      val accumulator = agg.createAccumulator()
+      agg.accumulate(accumulator, fieldValue)
+      output.setField(groupingKeys.length + i, accumulator)
     }
+
     for (i <- groupingKeys.indices) {
       output.setField(i, input.getField(groupingKeys(i)))
     }
@@ -103,3 +108,4 @@ class DataSetWindowAggregateMapFunction(
     returnType
   }
 }
+
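Each input record therefore becomes a partial row of arity groupingKeys.length + aggregates.length + 1: grouping keys first, then one freshly created single-value accumulator per aggregate, then the (possibly window-aligned) rowtime in the last field. A hypothetical walk-through of one map invocation with one key, reusing the IntSumAgg sketched above (all names and values illustrative):

    import org.apache.flink.types.Row

    // one input record: key = "a", aggregated value = 7, rowtime = 42L
    val agg = new IntSumAgg                // hypothetical aggregate from the sketch above
    val output = new Row(3)                // 1 grouping key + 1 accumulator + 1 rowtime field

    output.setField(0, "a")                // grouping keys occupy the leading fields
    val acc = agg.createAccumulator()
    agg.accumulate(acc, 7)                 // seed the accumulator with this record's value
    output.setField(1, acc)                // accumulators sit right after the keys
    output.setField(2, 42L)                // rowtime is always the last field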
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala
index ed49dc3..51c614d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala
@@ -23,28 +23,20 @@ import org.apache.flink.types.Row
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow
 import org.apache.flink.util.Collector
+
 /**
  *
  * Computes the final aggregate value from incrementally computed aggregates.
  *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- *                         and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- *                         index in output Row.
+ * @param windowStartPos the start position of window
+ * @param windowEndPos the end position of window
  * @param finalRowArity The arity of the final output row.
  */
 class IncrementalAggregateAllTimeWindowFunction(
-    private val aggregates: Array[Aggregate[_ <: Any]],
-    private val groupKeysMapping: Array[(Int, Int)],
-    private val aggregateMapping: Array[(Int, Int)],
-    private val finalRowArity: Int,
     private val windowStartPos: Option[Int],
-    private val windowEndPos: Option[Int])
+    private val windowEndPos: Option[Int],
+    private val finalRowArity: Int)
   extends IncrementalAggregateAllWindowFunction[TimeWindow](
-    aggregates,
-    groupKeysMapping,
-    aggregateMapping,
     finalRowArity) {

   private var collector: TimeWindowPropertyCollector = _
@@ -55,9 +47,9 @@ class IncrementalAggregateAllTimeWindowFunction(
   }

   override def apply(
-    window: TimeWindow,
-    records: Iterable[Row],
-    out: Collector[Row]): Unit = {
+      window: TimeWindow,
+      records: Iterable[Row],
+      out: Collector[Row]): Unit = {

     // set collector and window
     collector.wrappedCollector = out

http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala
index 3c41a62..00aba1f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllWindowFunction.scala
@@ -28,25 +28,15 @@ import org.apache.flink.util.{Collector, Preconditions}

 /**
  * Computes the final aggregate value from incrementally computed aggregates.
  *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- *                         and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- *                         index in output Row.
- * @param finalRowArity    The arity of the final output row.
+ * @param finalRowArity The arity of the final output row.
  */
 class IncrementalAggregateAllWindowFunction[W <: Window](
-    private val aggregates: Array[Aggregate[_ <: Any]],
-    private val groupKeysMapping: Array[(Int, Int)],
-    private val aggregateMapping: Array[(Int, Int)],
     private val finalRowArity: Int)
   extends RichAllWindowFunction[Row, Row, W] {

   private var output: Row = _

   override def open(parameters: Configuration): Unit = {
-    Preconditions.checkNotNull(aggregates)
-    Preconditions.checkNotNull(groupKeysMapping)
     output = new Row(finalRowArity)
   }

@@ -55,25 +45,15 @@ class IncrementalAggregateAllWindowFunction[W <: Window](
    * Row based on the mapping relation between intermediate aggregate data and output data.
    */
   override def apply(
-    window: W,
-    records: Iterable[Row],
-    out: Collector[Row]): Unit = {
+      window: W,
+      records: Iterable[Row],
+      out: Collector[Row]): Unit = {

     val iterator = records.iterator

     if (iterator.hasNext) {
       val record = iterator.next()
-      // Set group keys value to final output.
-      groupKeysMapping.foreach {
-        case (after, previous) =>
-          output.setField(after, record.getField(previous))
-      }
-      // Evaluate final aggregate value and set to output.
-      aggregateMapping.foreach {
-        case (after, previous) =>
-          output.setField(after, aggregates(previous).evaluate(record))
-      }
-      out.collect(output)
+      out.collect(record)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala
deleted file mode 100644
index 14b44e8..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateReduceFunction.scala
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import org.apache.flink.api.common.functions.ReduceFunction
-import org.apache.flink.types.Row
-import org.apache.flink.util.Preconditions
-
-/**
- * Incrementally computes group window aggregates.
- *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- *                         and output Row.
- */
-class IncrementalAggregateReduceFunction(
-    private val aggregates: Array[Aggregate[_]],
-    private val groupKeysMapping: Array[(Int, Int)],
-    private val intermediateRowArity: Int)
-  extends ReduceFunction[Row] {
-
-  Preconditions.checkNotNull(aggregates)
-  Preconditions.checkNotNull(groupKeysMapping)
-
-  /**
-   * For Incremental intermediate aggregate Rows, merge value1 and value2
-   * into aggregate buffer, return aggregate buffer.
-   *
-   * @param value1 The first value to combined.
-   * @param value2 The second value to combined.
-   * @return accumulatorRow A resulting row that combines two input values.
-   *
-   */
-  override def reduce(value1: Row, value2: Row): Row = {
-
-    // TODO: once FLINK-5105 is solved, we can avoid creating a new row for each invocation
-    //   and directly merge value1 and value2.
-    val accumulatorRow = new Row(intermediateRowArity)
-
-    // copy all fields of value1 into accumulatorRow
-    (0 until intermediateRowArity)
-      .foreach(i => accumulatorRow.setField(i, value1.getField(i)))
-    // merge value2 to accumulatorRow
-    aggregates.foreach(_.merge(value2, accumulatorRow))
-
-    accumulatorRow
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala
index a6626d9..dccb4f6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala
@@ -28,24 +28,20 @@ import org.apache.flink.util.Collector

 /**
  * Computes the final aggregate value from incrementally computed aggregates.
  *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- *                         and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- *                         index in output Row.
- * @param finalRowArity The arity of the final output row.
+ * @param windowStartPos the start position of window
+ * @param windowEndPos the end position of window
+ * @param finalRowArity The arity of the final output row
  */
 class IncrementalAggregateTimeWindowFunction(
-    private val aggregates: Array[Aggregate[_ <: Any]],
-    private val groupKeysMapping: Array[(Int, Int)],
-    private val aggregateMapping: Array[(Int, Int)],
-    private val finalRowArity: Int,
+    private val numGroupingKey: Int,
+    private val numAggregates: Int,
     private val windowStartPos: Option[Int],
-    private val windowEndPos: Option[Int])
+    private val windowEndPos: Option[Int],
+    private val finalRowArity: Int)
   extends IncrementalAggregateWindowFunction[TimeWindow](
-    aggregates,
-    groupKeysMapping,
-    aggregateMapping, finalRowArity) {
+    numGroupingKey,
+    numAggregates,
+    finalRowArity) {

   private var collector: TimeWindowPropertyCollector = _

@@ -55,10 +51,10 @@ class IncrementalAggregateTimeWindowFunction(
   }

   override def apply(
-    key: Tuple,
-    window: TimeWindow,
-    records: Iterable[Row],
-    out: Collector[Row]): Unit = {
+      key: Tuple,
+      window: TimeWindow,
+      records: Iterable[Row],
+      out: Collector[Row]): Unit = {

     // set collector and window
     collector.wrappedCollector = out

http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala
index 30f7a7b..a4d4837 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateWindowFunction.scala
@@ -24,30 +24,24 @@ import org.apache.flink.types.Row
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction
 import org.apache.flink.streaming.api.windowing.windows.Window
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector

 /**
  * Computes the final aggregate value from incrementally computed aggregates.
  *
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- *                         and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- *                         index in output Row.
- * @param finalRowArity The arity of the final output row.
+ * @param numGroupingKey The number of grouping keys.
+ * @param numAggregates The number of aggregates.
+ * @param finalRowArity The arity of the final output row.
  */
 class IncrementalAggregateWindowFunction[W <: Window](
-    private val aggregates: Array[Aggregate[_ <: Any]],
-    private val groupKeysMapping: Array[(Int, Int)],
-    private val aggregateMapping: Array[(Int, Int)],
+    private val numGroupingKey: Int,
+    private val numAggregates: Int,
     private val finalRowArity: Int)
   extends RichWindowFunction[Row, Row, Tuple, W] {

   private var output: Row = _

   override def open(parameters: Configuration): Unit = {
-    Preconditions.checkNotNull(aggregates)
-    Preconditions.checkNotNull(groupKeysMapping)
     output = new Row(finalRowArity)
   }

@@ -56,25 +50,23 @@ class IncrementalAggregateWindowFunction[W <: Window](
    * Row based on the mapping relation between intermediate aggregate data and output data.
    */
   override def apply(
-    key: Tuple,
-    window: W,
-    records: Iterable[Row],
-    out: Collector[Row]): Unit = {
+      key: Tuple,
+      window: W,
+      records: Iterable[Row],
+      out: Collector[Row]): Unit = {

     val iterator = records.iterator

     if (iterator.hasNext) {
       val record = iterator.next()
-      // Set group keys value to final output.
-      groupKeysMapping.foreach {
-        case (after, previous) =>
-          output.setField(after, record.getField(previous))
+
+      for (i <- 0 until numGroupingKey) {
+        output.setField(i, key.getField(i))
       }
-      // Evaluate final aggregate value and set to output.
-      aggregateMapping.foreach {
-        case (after, previous) =>
-          output.setField(after, aggregates(previous).evaluate(record))
+      for (i <- 0 until numAggregates) {
+        output.setField(numGroupingKey + i, record.getField(i))
       }
+
       out.collect(output)
     }
   }
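After this change the window function evaluates nothing itself: grouping keys are copied from the keyed stream's Tuple key, and the already-final aggregate values come from the single pre-aggregated record. A hypothetical walk-through with one key and two aggregates (all names and values illustrative):

    import org.apache.flink.api.java.tuple.Tuple1
    import org.apache.flink.types.Row

    val numGroupingKey = 1
    val numAggregates = 2

    val key = Tuple1.of("Hello")           // key of this window pane
    val record = new Row(numAggregates)    // pre-aggregated values, e.g. count and sum
    record.setField(0, 2L)
    record.setField(1, 4)

    val output = new Row(numGroupingKey + numAggregates)
    for (i <- 0 until numGroupingKey) {
      output.setField(i, key.getField(i))                      // keys first
    }
    for (i <- 0 until numAggregates) {
      output.setField(numGroupingKey + i, record.getField(i))  // then aggregate values
    }
    // output now holds ("Hello", 2L, 4)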
http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala
index a243db7..818cd0e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala
@@ -133,17 +133,17 @@ class AggregationsITCase extends StreamingMultipleProgramsTestBase {
     val windowedTable = table
       .window(Tumble over 5.milli on 'rowtime as 'w)
       .groupBy('w, 'string)
-      .select('string, 'int.count, 'int.avg, 'w.start, 'w.end)
+      .select('string, 'int.count, 'int.avg, 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end)

     val results = windowedTable.toDataStream[Row]
     results.addSink(new StreamITCase.StringSink)
     env.execute()

     val expected = Seq(
-      "Hello world,1,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01",
-      "Hello world,1,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02",
-      "Hello,2,2,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005",
-      "Hi,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005")
+      "Hello world,1,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01",
+      "Hello world,1,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02",
+      "Hello,2,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005",
+      "Hi,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005")
    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
  }
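The widened expectations follow directly from the old ones: the 'Hello' pane holds two rows whose int values must both be 2 (min = max = 2 forces this), giving count 2, avg 2, min 2, max 2, sum 4, while the single-row 'Hello world' and 'Hi' panes simply repeat their only value (3 and 1 respectively) for avg, min, max, and sum.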
