[FLINK-5768] [table] Refactor DataSet and DataStream aggregations to use UDAGG interface.
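A user-defined aggregate under the refactored interface extends the AggregateFunction base class and implements createAccumulator(), accumulate(), getValue(), merge(), and optionally getAccumulatorType(). The following is an illustrative sketch only, modeled on the built-in CountAggFunction touched by this commit; the names MyCountAccumulator and MyCountAggFunction are placeholders and not part of the change:

    import java.util.{List => JList}

    import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
    import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
    import org.apache.flink.api.java.typeutils.TupleTypeInfo
    import org.apache.flink.table.functions.{Accumulator, AggregateFunction}

    /** Accumulator holding the running count. */
    class MyCountAccumulator extends JTuple1[Long] with Accumulator {
      f0 = 0L // count
    }

    /** Counts non-null input values, analogous to the built-in CountAggFunction. */
    class MyCountAggFunction extends AggregateFunction[Long] {

      override def createAccumulator(): Accumulator = new MyCountAccumulator

      override def accumulate(accumulator: Accumulator, value: Any): Unit = {
        if (value != null) {
          accumulator.asInstanceOf[MyCountAccumulator].f0 += 1L
        }
      }

      override def getValue(accumulator: Accumulator): Long =
        accumulator.asInstanceOf[MyCountAccumulator].f0

      override def merge(accumulators: JList[Accumulator]): Accumulator = {
        val ret = new MyCountAccumulator
        var i = 0
        while (i < accumulators.size()) {
          ret.f0 += accumulators.get(i).asInstanceOf[MyCountAccumulator].f0
          i += 1
        }
        ret
      }

      override def getAccumulatorType(): TypeInformation[_] =
        new TupleTypeInfo((new MyCountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO)
    }

On the DataStream side, such functions are wrapped by the new AggregateAggFunction and handed to WindowedStream.aggregate() together with the accumulator and result row types, as shown in the DataStreamAggregate changes below.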
- DataStream aggregates use new WindowedStream.aggregate() operator. This closes #3423. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/438276de Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/438276de Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/438276de Branch: refs/heads/master Commit: 438276de8fab4f1a8f2b62b6452c2e5b2998ce5a Parents: 7fe0eb4 Author: shaoxuan-wang <[email protected]> Authored: Mon Feb 27 19:09:30 2017 +0800 Committer: Fabian Hueske <[email protected]> Committed: Thu Mar 2 21:31:18 2017 +0100 ---------------------------------------------------------------------- .../table/functions/AggregateFunction.scala | 13 +- .../functions/aggfunctions/AvgAggFunction.scala | 91 ++-- .../aggfunctions/CountAggFunction.scala | 19 +- .../functions/aggfunctions/MaxAggFunction.scala | 66 ++- .../functions/aggfunctions/MinAggFunction.scala | 66 ++- .../functions/aggfunctions/SumAggFunction.scala | 76 ++- .../utils/UserDefinedFunctionUtils.scala | 4 +- .../plan/nodes/dataset/DataSetAggregate.scala | 25 +- .../nodes/dataset/DataSetWindowAggregate.scala | 28 +- .../nodes/datastream/DataStreamAggregate.scala | 152 ++---- .../aggregate/AggregateAggFunction.scala | 79 ++++ .../AggregateAllTimeWindowFunction.scala | 52 --- .../aggregate/AggregateAllWindowFunction.scala | 41 -- .../aggregate/AggregateMapFunction.scala | 22 +- .../AggregateReduceCombineFunction.scala | 89 ++-- .../AggregateReduceGroupFunction.scala | 96 ++-- .../aggregate/AggregateTimeWindowFunction.scala | 57 --- .../table/runtime/aggregate/AggregateUtil.scala | 464 ++++++++----------- .../aggregate/AggregateWindowFunction.scala | 46 -- ...ionWindowAggregateCombineGroupFunction.scala | 88 ++-- ...sionWindowAggregateReduceGroupFunction.scala | 104 +++-- ...umbleCountWindowAggReduceGroupFunction.scala | 49 +- ...mbleTimeWindowAggReduceCombineFunction.scala | 58 ++- ...TumbleTimeWindowAggReduceGroupFunction.scala | 60 ++- .../DataSetWindowAggregateMapFunction.scala | 18 +- ...rementalAggregateAllTimeWindowFunction.scala | 24 +- .../IncrementalAggregateAllWindowFunction.scala | 30 +- .../IncrementalAggregateReduceFunction.scala | 63 --- ...IncrementalAggregateTimeWindowFunction.scala | 32 +- .../IncrementalAggregateWindowFunction.scala | 40 +- .../scala/stream/table/AggregationsITCase.scala | 10 +- .../aggfunctions/AggFunctionTestBase.scala | 2 +- .../dataset/DataSetWindowAggregateITCase.scala | 52 ++- 33 files changed, 1050 insertions(+), 1066 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala index e15a8c4..178b439 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala @@ -19,12 +19,14 @@ package org.apache.flink.table.functions import java.util.{List => JList} +import org.apache.flink.api.common.typeinfo.TypeInformation + /** * Base class for User-Defined Aggregates. 
* * @tparam T the type of the aggregation result */ -trait AggregateFunction[T] extends UserDefinedFunction { +abstract class AggregateFunction[T] extends UserDefinedFunction { /** * Create and init the Accumulator for this [[AggregateFunction]]. * @@ -61,6 +63,15 @@ trait AggregateFunction[T] extends UserDefinedFunction { * @return the resulting accumulator */ def merge(accumulators: JList[Accumulator]): Accumulator + + /** + * Returns the [[TypeInformation]] of the accumulator. + * This function is optional and can be implemented if the accumulator type cannot automatically + * inferred from the instance returned by [[createAccumulator()]]. + * + * @return The type information for the accumulator. + */ + def getAccumulatorType(): TypeInformation[_] = null } /** http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala index f4c0b7b..534bb03 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala @@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions import java.math.{BigDecimal, BigInteger} import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} +import org.apache.flink.api.java.typeutils.TupleTypeInfo import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +/** The initial accumulator for Integral Avg aggregate function */ +class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator { + f0 = 0L //sum + f1 = 0L //count +} + /** * Base class for built-in Integral Avg aggregate function * @@ -29,12 +38,6 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction} */ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] { - /** The initial accumulator for Integral Avg aggregate function */ - class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator { - f0 = 0 //sum - f1 = 0 //count - } - override def createAccumulator(): Accumulator = { new IntegralAvgAccumulator } @@ -44,7 +47,7 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] { val v = value.asInstanceOf[Number].longValue() val accum = accumulator.asInstanceOf[IntegralAvgAccumulator] accum.f0 += v - accum.f1 += 1 + accum.f1 += 1L } } @@ -69,6 +72,13 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] { ret } + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo( + new IntegralAvgAccumulator().getClass, + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO) + } + /** * Convert the intermediate result to the expected aggregation result type * @@ -100,6 +110,13 @@ class IntAvgAggFunction extends IntegralAvgAggFunction[Int] { override def resultTypeConvert(value: Long): Int = value.toInt } +/** The initial accumulator for Big Integral Avg aggregate function */ +class BigIntegralAvgAccumulator + extends JTuple2[BigInteger, Long] with Accumulator { + f0 = 
BigInteger.ZERO //sum + f1 = 0L //count +} + /** * Base Class for Built-in Big Integral Avg aggregate function * @@ -107,13 +124,6 @@ class IntAvgAggFunction extends IntegralAvgAggFunction[Int] { */ abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] { - /** The initial accumulator for Big Integral Avg aggregate function */ - class BigIntegralAvgAccumulator - extends JTuple2[BigInteger, Long] with Accumulator { - f0 = BigInteger.ZERO //sum - f1 = 0 //count - } - override def createAccumulator(): Accumulator = { new BigIntegralAvgAccumulator } @@ -123,7 +133,7 @@ abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] { val v = value.asInstanceOf[Long] val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator] a.f0 = a.f0.add(BigInteger.valueOf(v)) - a.f1 += 1 + a.f1 += 1L } } @@ -148,6 +158,13 @@ abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] { ret } + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo( + new BigIntegralAvgAccumulator().getClass, + BasicTypeInfo.BIG_INT_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO) + } + /** * Convert the intermediate result to the expected aggregation result type * @@ -166,6 +183,12 @@ class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] { override def resultTypeConvert(value: BigInteger): Long = value.longValue() } +/** The initial accumulator for Floating Avg aggregate function */ +class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator { + f0 = 0 //sum + f1 = 0L //count +} + /** * Base class for built-in Floating Avg aggregate function * @@ -173,12 +196,6 @@ class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] { */ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] { - /** The initial accumulator for Floating Avg aggregate function */ - class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator { - f0 = 0 //sum - f1 = 0 //count - } - override def createAccumulator(): Accumulator = { new FloatingAvgAccumulator } @@ -188,7 +205,7 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] { val v = value.asInstanceOf[Number].doubleValue() val accum = accumulator.asInstanceOf[FloatingAvgAccumulator] accum.f0 += v - accum.f1 += 1 + accum.f1 += 1L } } @@ -213,6 +230,13 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] { ret } + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo( + new FloatingAvgAccumulator().getClass, + BasicTypeInfo.DOUBLE_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO) + } + /** * Convert the intermediate result to the expected aggregation result type * @@ -237,18 +261,18 @@ class DoubleAvgAggFunction extends FloatingAvgAggFunction[Double] { override def resultTypeConvert(value: Double): Double = value } +/** The initial accumulator for Big Decimal Avg aggregate function */ +class DecimalAvgAccumulator + extends JTuple2[BigDecimal, Long] with Accumulator { + f0 = BigDecimal.ZERO //sum + f1 = 0L //count +} + /** * Base class for built-in Big Decimal Avg aggregate function */ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] { - /** The initial accumulator for Big Decimal Avg aggregate function */ - class DecimalAvgAccumulator - extends JTuple2[BigDecimal, Long] with Accumulator { - f0 = BigDecimal.ZERO //sum - f1 = 0 //count - } - override def createAccumulator(): Accumulator = { new DecimalAvgAccumulator } @@ -262,7 +286,7 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] { } else 
{ accum.f0 = accum.f0.add(v) } - accum.f1 += 1 + accum.f1 += 1L } } @@ -286,4 +310,11 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] { } ret } + + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo( + new DecimalAvgAccumulator().getClass, + BasicTypeInfo.BIG_DEC_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala index 8b903d1..cf884ed 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala @@ -18,22 +18,25 @@ package org.apache.flink.table.functions.aggfunctions import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1} +import org.apache.flink.api.java.typeutils.TupleTypeInfo import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +/** The initial accumulator for count aggregate function */ +class CountAccumulator extends JTuple1[Long] with Accumulator { + f0 = 0L //count +} + /** * built-in count aggregate function */ class CountAggFunction extends AggregateFunction[Long] { - /** The initial accumulator for count aggregate function */ - class CountAccumulator extends JTuple1[Long] with Accumulator { - f0 = 0 //count - } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { if (value != null) { - accumulator.asInstanceOf[CountAccumulator].f0 += 1 + accumulator.asInstanceOf[CountAccumulator].f0 += 1L } } @@ -54,4 +57,8 @@ class CountAggFunction extends AggregateFunction[Long] { override def createAccumulator(): Accumulator = { new CountAccumulator } + + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala index 20041ee..62ff88c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala @@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} +import org.apache.flink.api.java.typeutils.TupleTypeInfo import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +/** The initial 
accumulator for Max aggregate function */ +class MaxAccumulator[T] extends JTuple2[T, Boolean] with Accumulator { + f0 = 0.asInstanceOf[T] //max + f1 = false +} + /** * Base class for built-in Max aggregate function * @@ -29,20 +38,14 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction} */ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] { - /** The initial accumulator for Max aggregate function */ - class MaxAccumulator extends JTuple2[T, Boolean] with Accumulator { - f0 = 0.asInstanceOf[T] //max - f1 = false - } - override def createAccumulator(): Accumulator = { - new MaxAccumulator + new MaxAccumulator[T] } override def accumulate(accumulator: Accumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[MaxAccumulator] + val a = accumulator.asInstanceOf[MaxAccumulator[T]] if (!a.f1 || ord.compare(a.f0, v) < 0) { a.f0 = v a.f1 = true @@ -51,7 +54,7 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun } override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[MaxAccumulator] + val a = accumulator.asInstanceOf[MaxAccumulator[T]] if (a.f1) { a.f0 } else { @@ -63,50 +66,73 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun val ret = accumulators.get(0) var i: Int = 1 while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[MaxAccumulator] + val a = accumulators.get(i).asInstanceOf[MaxAccumulator[T]] if (a.f1) { - accumulate(ret.asInstanceOf[MaxAccumulator], a.f0) + accumulate(ret.asInstanceOf[MaxAccumulator[T]], a.f0) } i += 1 } ret } + + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo( + new MaxAccumulator[T].getClass, + getValueTypeInfo, + BasicTypeInfo.BOOLEAN_TYPE_INFO) + } + + def getValueTypeInfo: TypeInformation[_] } /** * Built-in Byte Max aggregate function */ -class ByteMaxAggFunction extends MaxAggFunction[Byte] +class ByteMaxAggFunction extends MaxAggFunction[Byte] { + override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO +} /** * Built-in Short Max aggregate function */ -class ShortMaxAggFunction extends MaxAggFunction[Short] +class ShortMaxAggFunction extends MaxAggFunction[Short] { + override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO +} /** * Built-in Int Max aggregate function */ -class IntMaxAggFunction extends MaxAggFunction[Int] +class IntMaxAggFunction extends MaxAggFunction[Int] { + override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO +} /** * Built-in Long Max aggregate function */ -class LongMaxAggFunction extends MaxAggFunction[Long] +class LongMaxAggFunction extends MaxAggFunction[Long] { + override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO +} /** * Built-in Float Max aggregate function */ -class FloatMaxAggFunction extends MaxAggFunction[Float] +class FloatMaxAggFunction extends MaxAggFunction[Float] { + override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO +} /** * Built-in Double Max aggregate function */ -class DoubleMaxAggFunction extends MaxAggFunction[Double] +class DoubleMaxAggFunction extends MaxAggFunction[Double] { + override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO +} /** * Built-in Boolean Max aggregate function */ -class BooleanMaxAggFunction extends MaxAggFunction[Boolean] +class BooleanMaxAggFunction extends MaxAggFunction[Boolean] { + override def getValueTypeInfo = BasicTypeInfo.BOOLEAN_TYPE_INFO +} /** * Built-in Big Decimal 
Max aggregate function @@ -116,11 +142,13 @@ class DecimalMaxAggFunction extends MaxAggFunction[BigDecimal] { override def accumulate(accumulator: Accumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[BigDecimal] - val accum = accumulator.asInstanceOf[MaxAccumulator] + val accum = accumulator.asInstanceOf[MaxAccumulator[BigDecimal]] if (!accum.f1 || accum.f0.compareTo(v) < 0) { accum.f0 = v accum.f1 = true } } } + + override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO } http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala index 16461ae..cddb873 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala @@ -19,9 +19,18 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} +import org.apache.flink.api.java.typeutils.TupleTypeInfo import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +/** The initial accumulator for Min aggregate function */ +class MinAccumulator[T] extends JTuple2[T, Boolean] with Accumulator { + f0 = 0.asInstanceOf[T] //min + f1 = false +} + /** * Base class for built-in Min aggregate function * @@ -29,20 +38,14 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction} */ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] { - /** The initial accumulator for Min aggregate function */ - class MinAccumulator extends JTuple2[T, Boolean] with Accumulator { - f0 = 0.asInstanceOf[T] //min - f1 = false - } - override def createAccumulator(): Accumulator = { - new MinAccumulator + new MinAccumulator[T] } override def accumulate(accumulator: Accumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[MinAccumulator] + val a = accumulator.asInstanceOf[MinAccumulator[T]] if (!a.f1 || ord.compare(a.f0, v) > 0) { a.f0 = v a.f1 = true @@ -51,7 +54,7 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun } override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[MinAccumulator] + val a = accumulator.asInstanceOf[MinAccumulator[T]] if (a.f1) { a.f0 } else { @@ -63,50 +66,73 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun val ret = accumulators.get(0) var i: Int = 1 while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[MinAccumulator] + val a = accumulators.get(i).asInstanceOf[MinAccumulator[T]] if (a.f1) { - accumulate(ret.asInstanceOf[MinAccumulator], a.f0) + accumulate(ret.asInstanceOf[MinAccumulator[T]], a.f0) } i += 1 } ret } + + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo( + new MinAccumulator[T].getClass, + getValueTypeInfo, + BasicTypeInfo.BOOLEAN_TYPE_INFO) + } + + def 
getValueTypeInfo: TypeInformation[_] } /** * Built-in Byte Min aggregate function */ -class ByteMinAggFunction extends MinAggFunction[Byte] +class ByteMinAggFunction extends MinAggFunction[Byte] { + override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO +} /** * Built-in Short Min aggregate function */ -class ShortMinAggFunction extends MinAggFunction[Short] +class ShortMinAggFunction extends MinAggFunction[Short] { + override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO +} /** * Built-in Int Min aggregate function */ -class IntMinAggFunction extends MinAggFunction[Int] +class IntMinAggFunction extends MinAggFunction[Int] { + override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO +} /** * Built-in Long Min aggregate function */ -class LongMinAggFunction extends MinAggFunction[Long] +class LongMinAggFunction extends MinAggFunction[Long] { + override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO +} /** * Built-in Float Min aggregate function */ -class FloatMinAggFunction extends MinAggFunction[Float] +class FloatMinAggFunction extends MinAggFunction[Float] { + override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO +} /** * Built-in Double Min aggregate function */ -class DoubleMinAggFunction extends MinAggFunction[Double] +class DoubleMinAggFunction extends MinAggFunction[Double] { + override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO +} /** * Built-in Boolean Min aggregate function */ -class BooleanMinAggFunction extends MinAggFunction[Boolean] +class BooleanMinAggFunction extends MinAggFunction[Boolean] { + override def getValueTypeInfo = BasicTypeInfo.BOOLEAN_TYPE_INFO +} /** * Built-in Big Decimal Min aggregate function @@ -116,11 +142,13 @@ class DecimalMinAggFunction extends MinAggFunction[BigDecimal] { override def accumulate(accumulator: Accumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[BigDecimal] - val accum = accumulator.asInstanceOf[MinAccumulator] + val accum = accumulator.asInstanceOf[MinAccumulator[BigDecimal]] if (!accum.f1 || accum.f0.compareTo(v) > 0) { accum.f0 = v accum.f1 = true } } } + + override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO } http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala index b04d8c0..78fdb8e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala @@ -19,9 +19,15 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} +import org.apache.flink.api.java.typeutils.TupleTypeInfo import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +/** The initial accumulator for Sum aggregate function */ +class SumAccumulator[T] extends JTuple2[T, Boolean] with Accumulator + /** * Base class for built-in Sum aggregate function * @@ -29,29 +35,26 @@ import 
org.apache.flink.table.functions.{Accumulator, AggregateFunction} */ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] { - /** The initial accumulator for Sum aggregate function */ - class SumAccumulator extends JTuple2[T, Boolean] with Accumulator { - f0 = numeric.zero //sum - f1 = false - } - private val numeric = implicitly[Numeric[T]] override def createAccumulator(): Accumulator = { - new SumAccumulator + val acc = new SumAccumulator[T]() + acc.f0 = numeric.zero //sum + acc.f1 = false + acc } override def accumulate(accumulator: Accumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[SumAccumulator] + val a = accumulator.asInstanceOf[SumAccumulator[T]] a.f0 = numeric.plus(v, a.f0) a.f1 = true } } override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[SumAccumulator] + val a = accumulator.asInstanceOf[SumAccumulator[T]] if (a.f1) { a.f0 } else { @@ -60,10 +63,10 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] { } override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = createAccumulator().asInstanceOf[SumAccumulator] + val ret = createAccumulator().asInstanceOf[SumAccumulator[T]] var i: Int = 0 while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[SumAccumulator] + val a = accumulators.get(i).asInstanceOf[SumAccumulator[T]] if (a.f1) { ret.f0 = numeric.plus(ret.f0, a.f0) ret.f1 = true @@ -72,50 +75,70 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] { } ret } + + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo( + (new SumAccumulator).getClass, + getValueTypeInfo, + BasicTypeInfo.BOOLEAN_TYPE_INFO) + } + + def getValueTypeInfo: TypeInformation[_] } /** * Built-in Byte Sum aggregate function */ -class ByteSumAggFunction extends SumAggFunction[Byte] +class ByteSumAggFunction extends SumAggFunction[Byte] { + override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO +} /** * Built-in Short Sum aggregate function */ -class ShortSumAggFunction extends SumAggFunction[Short] +class ShortSumAggFunction extends SumAggFunction[Short] { + override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO +} /** * Built-in Int Sum aggregate function */ -class IntSumAggFunction extends SumAggFunction[Int] +class IntSumAggFunction extends SumAggFunction[Int] { + override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO +} /** * Built-in Long Sum aggregate function */ -class LongSumAggFunction extends SumAggFunction[Long] +class LongSumAggFunction extends SumAggFunction[Long] { + override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO +} /** * Built-in Float Sum aggregate function */ -class FloatSumAggFunction extends SumAggFunction[Float] +class FloatSumAggFunction extends SumAggFunction[Float] { + override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO +} /** * Built-in Double Sum aggregate function */ -class DoubleSumAggFunction extends SumAggFunction[Double] +class DoubleSumAggFunction extends SumAggFunction[Double] { + override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO +} +/** The initial accumulator for Big Decimal Sum aggregate function */ +class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator { + f0 = BigDecimal.ZERO + f1 = false +} /** * Built-in Big Decimal Sum aggregate function */ class DecimalSumAggFunction extends AggregateFunction[BigDecimal] { - /** The initial accumulator for Big Decimal 
Sum aggregate function */ - class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator { - f0 = BigDecimal.ZERO - f1 = false - } - override def createAccumulator(): Accumulator = { new DecimalSumAccumulator } @@ -150,4 +173,11 @@ class DecimalSumAggFunction extends AggregateFunction[BigDecimal] { } ret } + + override def getAccumulatorType(): TypeInformation[_] = { + new TupleTypeInfo( + (new DecimalSumAccumulator).getClass, + BasicTypeInfo.BIG_DEC_TYPE_INFO, + BasicTypeInfo.BOOLEAN_TYPE_INFO) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index aec4fbb..21d28b5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -119,9 +119,9 @@ object UserDefinedFunctionUtils { } /** - * Check if a given method exits in the given function + * Check if a given method exists in the given function */ - def ifMethodExitInFunction(method: String, function: UserDefinedFunction): Boolean = { + def ifMethodExistInFunction(method: String, function: UserDefinedFunction): Boolean = { val methods = function .getClass .getMethods http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala index 206e562..a88bcfe 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala @@ -44,9 +44,7 @@ class DataSetAggregate( inputType: RelDataType, grouping: Array[Int], inGroupingSet: Boolean) - extends SingleRel(cluster, traitSet, inputNode) - with CommonAggregate - with DataSetRel { + extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataSetRel { override def deriveRowType(): RelDataType = rowRelDataType @@ -63,11 +61,13 @@ class DataSetAggregate( } override def toString: String = { - s"Aggregate(${ if (!grouping.isEmpty) { - s"groupBy: (${groupingToString(inputType, grouping)}), " - } else { - "" - }}select: (${aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)}))" + s"Aggregate(${ + if (!grouping.isEmpty) { + s"groupBy: (${groupingToString(inputType, grouping)}), " + } else { + "" + } + }select: (${aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)}))" } override def explainTerms(pw: RelWriter): RelWriter = { @@ -76,7 +76,7 @@ class DataSetAggregate( .item("select", aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)) } - override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): 
RelOptCost = { + override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { val child = this.getInput val rowCnt = metadata.getRowCount(child) @@ -87,8 +87,6 @@ class DataSetAggregate( override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = { - val config = tableEnv.getConfig - val groupingKeys = grouping.indices.toArray val mapFunction = AggregateUtil.createPrepareMapFunction( @@ -107,9 +105,7 @@ class DataSetAggregate( val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil) val prepareOpName = s"prepare select: ($aggString)" - val mappedInput = inputDS - .map(mapFunction) - .name(prepareOpName) + val mappedInput = inputDS.map(mapFunction).name(prepareOpName) val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] @@ -127,6 +123,7 @@ class DataSetAggregate( else { // global aggregation val aggOpName = s"select:($aggString)" + mappedInput.asInstanceOf[DataSet[Row]] .reduceGroup(groupReduceFunction) .returns(rowTypeInfo) http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala index 48de822..597be8c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala @@ -47,9 +47,7 @@ class DataSetWindowAggregate( rowRelDataType: RelDataType, inputType: RelDataType, grouping: Array[Int]) - extends SingleRel(cluster, traitSet, inputNode) - with CommonAggregate - with DataSetRel { + extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataSetRel { override def deriveRowType() = rowRelDataType @@ -97,7 +95,7 @@ class DataSetWindowAggregate( namedProperties)) } - override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { + override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { val child = this.getInput val rowCnt = metadata.getRowCount(child) val rowSize = this.estimateRowSize(child.getRowType) @@ -136,8 +134,8 @@ class DataSetWindowAggregate( private def createEventTimeTumblingWindowDataSet( inputDS: DataSet[Row], isTimeWindow: Boolean, - isParserCaseSensitive: Boolean) - : DataSet[Row] = { + isParserCaseSensitive: Boolean): DataSet[Row] = { + val mapFunction = createDataSetWindowPrepareMapFunction( window, namedAggregates, @@ -191,8 +189,7 @@ class DataSetWindowAggregate( private[this] def createEventTimeSessionWindowDataSet( inputDS: DataSet[Row], - isParserCaseSensitive: Boolean) - : DataSet[Row] = { + isParserCaseSensitive: Boolean): DataSet[Row] = { val groupingKeys = grouping.indices.toArray val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) @@ -207,10 +204,7 @@ class DataSetWindowAggregate( inputType, isParserCaseSensitive) - val mappedInput = - inputDS - .map(mapFunction) - .name(prepareOperatorName) + val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName) val mapReturnType = 
mapFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType @@ -218,7 +212,7 @@ class DataSetWindowAggregate( val rowTimeFieldPos = mapReturnType.getArity - 1 // do incremental aggregation - if (doAllSupportPartialAggregation( + if (doAllSupportPartialMerge( namedAggregates.map(_.getKey), inputType, grouping.length)) { @@ -267,10 +261,10 @@ class DataSetWindowAggregate( namedProperties) mappedInput.groupBy(groupingKeys: _*) - .sortGroup(rowTimeFieldPos, Order.ASCENDING) - .reduceGroup(groupReduceFunction) - .returns(rowTypeInfo) - .name(aggregateOperatorName) + .sortGroup(rowTimeFieldPos, Order.ASCENDING) + .reduceGroup(groupReduceFunction) + .returns(rowTypeInfo) + .name(aggregateOperatorName) } } // non-grouping window http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala index c21d008..50f8281 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala @@ -50,9 +50,7 @@ class DataStreamAggregate( rowRelDataType: RelDataType, inputType: RelDataType, grouping: Array[Int]) - extends SingleRel(cluster, traitSet, inputNode) - with CommonAggregate - with DataStreamRel { + extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { override def deriveRowType(): RelDataType = rowRelDataType @@ -91,12 +89,13 @@ class DataStreamAggregate( super.explainTerms(pw) .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty) .item("window", window) - .item("select", aggregationToString( - inputType, - grouping, - getRowType, - namedAggregates, - namedProperties)) + .item( + "select", aggregationToString( + inputType, + grouping, + getRowType, + namedAggregates, + namedProperties)) } override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { @@ -113,116 +112,61 @@ class DataStreamAggregate( namedAggregates, namedProperties) - val prepareOpName = s"prepare select: ($aggString)" val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + s"window: ($window), " + s"select: ($aggString)" val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" - val mapFunction = AggregateUtil.createPrepareMapFunction( - namedAggregates, - grouping, - inputType) - - val mappedInput = inputDS.map(mapFunction).name(prepareOpName) - - - // check whether all aggregates support partial aggregate - if (AggregateUtil.doAllSupportPartialAggregation( - namedAggregates.map(_.getKey), - inputType, - grouping.length)) { - // do Incremental Aggregation - val reduceFunction = AggregateUtil.createIncrementalAggregateReduceFunction( - namedAggregates, - inputType, - getRowType, - grouping) - // grouped / keyed aggregation - if (groupingKeys.length > 0) { - val windowFunction = AggregateUtil.createWindowIncrementalAggregationFunction( - window, - namedAggregates, - inputType, - rowRelDataType, - grouping, - namedProperties) - - val keyedStream = mappedInput.keyBy(groupingKeys: _*) - val windowedStream = - 
createKeyedWindowedStream(window, keyedStream) + // grouped / keyed aggregation + if (groupingKeys.length > 0) { + val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( + window, + groupingKeys.length, + namedAggregates.size, + rowRelDataType.getFieldCount, + namedProperties) + + val keyedStream = inputDS.keyBy(groupingKeys: _*) + val windowedStream = + createKeyedWindowedStream(window, keyedStream) .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] - windowedStream - .reduce(reduceFunction, windowFunction) - .returns(rowTypeInfo) - .name(keyedAggOpName) - } - // global / non-keyed aggregation - else { - val windowFunction = AggregateUtil.createAllWindowIncrementalAggregationFunction( - window, + val (aggFunction, accumulatorRowType, aggResultRowType) = + AggregateUtil.createDataStreamAggregateFunction( namedAggregates, inputType, rowRelDataType, - grouping, - namedProperties) + grouping) - val windowedStream = - createNonKeyedWindowedStream(window, mappedInput) - .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] - - windowedStream - .reduce(reduceFunction, windowFunction) - .returns(rowTypeInfo) - .name(nonKeyedAggOpName) - } + windowedStream + .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) + .name(keyedAggOpName) } + // global / non-keyed aggregation else { - // do non-Incremental Aggregation - // grouped / keyed aggregation - if (groupingKeys.length > 0) { - - val windowFunction = AggregateUtil.createWindowAggregationFunction( - window, - namedAggregates, - inputType, - rowRelDataType, - grouping, - namedProperties) + val windowFunction = AggregateUtil.createAggregationAllWindowFunction( + window, + rowRelDataType.getFieldCount, + namedProperties) - val keyedStream = mappedInput.keyBy(groupingKeys: _*) - val windowedStream = - createKeyedWindowedStream(window, keyedStream) - .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] + val windowedStream = + createNonKeyedWindowedStream(window, inputDS) + .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] - windowedStream - .apply(windowFunction) - .returns(rowTypeInfo) - .name(keyedAggOpName) - } - // global / non-keyed aggregation - else { - val windowFunction = AggregateUtil.createAllWindowAggregationFunction( - window, + val (aggFunction, accumulatorRowType, aggResultRowType) = + AggregateUtil.createDataStreamAggregateFunction( namedAggregates, inputType, rowRelDataType, - grouping, - namedProperties) - - val windowedStream = - createNonKeyedWindowedStream(window, mappedInput) - .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] + grouping) - windowedStream - .apply(windowFunction) - .returns(rowTypeInfo) + windowedStream + .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) .name(nonKeyedAggOpName) - } } } } + object DataStreamAggregate { @@ -242,8 +186,8 @@ object DataStreamAggregate { // TODO: EventTimeTumblingGroupWindow should sort the stream on event time // before applying the windowing logic. 
Otherwise, this would be the same as a // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " + - "currently not supported.") + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) @@ -258,8 +202,8 @@ object DataStreamAggregate { // TODO: EventTimeTumblingGroupWindow should sort the stream on event time // before applying the windowing logic. Otherwise, this would be the same as a // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " + - "currently not supported.") + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") case ProcessingTimeSessionGroupWindow(_, gap: Expression) => stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap))) @@ -284,8 +228,8 @@ object DataStreamAggregate { // TODO: EventTimeTumblingGroupWindow should sort the stream on event time // before applying the windowing logic. Otherwise, this would be the same as a // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " + - "currently not supported.") + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) @@ -300,8 +244,8 @@ object DataStreamAggregate { // TODO: EventTimeTumblingGroupWindow should sort the stream on event time // before applying the windowing logic. Otherwise, this would be the same as a // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " + - "currently not supported.") + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") case ProcessingTimeSessionGroupWindow(_, gap) => stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap))) http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala new file mode 100644 index 0000000..4d1579b --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.aggregate + +import java.util.{ArrayList => JArrayList, List => JList} +import org.apache.flink.api.common.functions.{AggregateFunction => DataStreamAggFunc} +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.types.Row + +/** + * Aggregate Function used for the aggregate operator in + * [[org.apache.flink.streaming.api.datastream.WindowedStream]] + * + * @param aggregates the list of all [[org.apache.flink.table.functions.AggregateFunction]] + * used for this aggregation + * @param aggFields the position (in the input Row) of the input value for each aggregate + */ +class AggregateAggFunction( + private val aggregates: Array[AggregateFunction[_]], + private val aggFields: Array[Int]) + extends DataStreamAggFunc[Row, Row, Row] { + + val aggsWithIdx: Array[(AggregateFunction[_], Int)] = aggregates.zipWithIndex + + override def createAccumulator(): Row = { + val accumulatorRow: Row = new Row(aggregates.length) + aggsWithIdx.foreach { case (agg, i) => + accumulatorRow.setField(i, agg.createAccumulator()) + } + accumulatorRow + } + + override def add(value: Row, accumulatorRow: Row) = { + + aggsWithIdx.foreach { case (agg, i) => + val acc = accumulatorRow.getField(i).asInstanceOf[Accumulator] + val v = value.getField(aggFields(i)) + agg.accumulate(acc, v) + } + } + + override def getResult(accumulatorRow: Row): Row = { + val output = new Row(aggFields.length) + + aggsWithIdx.foreach { case (agg, i) => + output.setField(i, agg.getValue(accumulatorRow.getField(i).asInstanceOf[Accumulator])) + } + output + } + + override def merge(aAccumulatorRow: Row, bAccumulatorRow: Row): Row = { + + aggsWithIdx.foreach { case (agg, i) => + val aAcc = aAccumulatorRow.getField(i).asInstanceOf[Accumulator] + val bAcc = bAccumulatorRow.getField(i).asInstanceOf[Accumulator] + val accumulators: JList[Accumulator] = new JArrayList[Accumulator]() + accumulators.add(aAcc) + accumulators.add(bAcc) + aAccumulatorRow.setField(i, agg.merge(accumulators)) + } + aAccumulatorRow + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala deleted file mode 100644 index 89f3b41..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllTimeWindowFunction.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.runtime.aggregate - -import java.lang.Iterable - -import org.apache.flink.api.common.functions.RichGroupReduceFunction -import org.apache.flink.types.Row -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.windowing.windows.TimeWindow -import org.apache.flink.util.Collector - -class AggregateAllTimeWindowFunction( - groupReduceFunction: RichGroupReduceFunction[Row, Row], - windowStartPos: Option[Int], - windowEndPos: Option[Int]) - extends AggregateAllWindowFunction[TimeWindow](groupReduceFunction) { - - private var collector: TimeWindowPropertyCollector = _ - - override def open(parameters: Configuration): Unit = { - collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) - super.open(parameters) - } - - override def apply(window: TimeWindow, input: Iterable[Row], out: Collector[Row]): Unit = { - - // set collector and window - collector.wrappedCollector = out - collector.windowStart = window.getStart - collector.windowEnd = window.getEnd - - // call wrapped reduce function with property collector - super.apply(window, input, collector) - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala deleted file mode 100644 index 10a06da..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAllWindowFunction.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.runtime.aggregate - -import java.lang.Iterable - -import org.apache.flink.api.common.functions.RichGroupReduceFunction -import org.apache.flink.types.Row -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction -import org.apache.flink.streaming.api.windowing.windows.Window -import org.apache.flink.util.Collector - -class AggregateAllWindowFunction[W <: Window]( - groupReduceFunction: RichGroupReduceFunction[Row, Row]) - extends RichAllWindowFunction[Row, Row, W] { - - override def open(parameters: Configuration): Unit = { - groupReduceFunction.open(parameters) - } - - override def apply(window: W, input: Iterable[Row], out: Collector[Row]): Unit = { - groupReduceFunction.reduce(input, out) - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala index 0033ff7..d936fbb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala @@ -22,34 +22,36 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.types.Row import org.apache.flink.configuration.Configuration +import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.util.Preconditions class AggregateMapFunction[IN, OUT]( - private val aggregates: Array[Aggregate[_]], + private val aggregates: Array[AggregateFunction[_]], private val aggFields: Array[Int], private val groupingKeys: Array[Int], @transient private val returnType: TypeInformation[OUT]) - extends RichMapFunction[IN, OUT] - with ResultTypeQueryable[OUT] { - + extends RichMapFunction[IN, OUT] with ResultTypeQueryable[OUT] { + private var output: Row = _ - + override def open(config: Configuration) { Preconditions.checkNotNull(aggregates) Preconditions.checkNotNull(aggFields) Preconditions.checkArgument(aggregates.length == aggFields.length) - val partialRowLength = groupingKeys.length + - aggregates.map(_.intermediateDataType.length).sum + val partialRowLength = groupingKeys.length + aggregates.length output = new Row(partialRowLength) } override def map(value: IN): OUT = { - + val input = value.asInstanceOf[Row] for (i <- aggregates.indices) { - val fieldValue = input.getField(aggFields(i)) - aggregates(i).prepare(fieldValue, output) + val agg = aggregates(i) + val accumulator = agg.createAccumulator() + agg.accumulate(accumulator, input.getField(aggFields(i))) + output.setField(groupingKeys.length + i, accumulator) } + for (i <- groupingKeys.indices) { output.setField(i, input.getField(groupingKeys(i))) } http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala index 5237ecf..06ac8fb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala @@ -19,61 +19,84 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable +import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.CombineFunction +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} import org.apache.flink.types.Row -import scala.collection.JavaConversions._ - /** - * It wraps the aggregate logic inside of - * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and - * [[org.apache.flink.api.java.operators.GroupCombineOperator]] - * - * @param aggregates The aggregate functions. - * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping The index mapping between aggregate function list and aggregated value - * index in output Row. - * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate - * Row and output Row. - */ + * It wraps the aggregate logic inside of + * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and + * [[org.apache.flink.api.java.operators.GroupCombineOperator]] + * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + * @param aggregateMapping The index mapping between aggregate function list and aggregated + * value + * index in output Row. + * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate + * Row and output Row. + * @param finalRowArity the arity of the final resulting row + */ class AggregateReduceCombineFunction( - private val aggregates: Array[Aggregate[_ <: Any]], + private val aggregates: Array[AggregateFunction[_ <: Any]], private val groupKeysMapping: Array[(Int, Int)], private val aggregateMapping: Array[(Int, Int)], private val groupingSetsMapping: Array[(Int, Int)], - private val intermediateRowArity: Int, private val finalRowArity: Int) extends AggregateReduceGroupFunction( aggregates, groupKeysMapping, aggregateMapping, groupingSetsMapping, - intermediateRowArity, - finalRowArity) - with CombineFunction[Row, Row] { + finalRowArity) with CombineFunction[Row, Row] { /** - * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer, - * - * @param records Sub-grouped intermediate aggregate Rows iterator. - * @return Combined intermediate aggregate Row. - * - */ + * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer, + * + * @param records Sub-grouped intermediate aggregate Rows iterator. + * @return Combined intermediate aggregate Row. + * + */ override def combine(records: Iterable[Row]): Row = { - // Initiate intermediate aggregate value. - aggregates.foreach(_.initiate(aggregateBuffer)) - - // Merge intermediate aggregate value to buffer. + // merge intermediate aggregate value to buffer. 
var last: Row = null - records.foreach((record) => { - aggregates.foreach(_.merge(record, aggregateBuffer)) + accumulatorList.foreach(_.clear()) + + val iterator = records.iterator() + + var count: Int = 0 + while (iterator.hasNext) { + val record = iterator.next() + count += 1 + // per each aggregator, collect its accumulators to a list + for (i <- aggregates.indices) { + accumulatorList(i).add(record.getField(groupKeysMapping.length + i) + .asInstanceOf[Accumulator]) + } + // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one + // accumulator + if (count > maxMergeLen) { + count = 0 + for (i <- aggregates.indices) { + val agg = aggregates(i) + val accumulator = agg.merge(accumulatorList(i)) + accumulatorList(i).clear() + accumulatorList(i).add(accumulator) + } + } last = record - }) + } + + for (i <- aggregates.indices) { + val agg = aggregates(i) + aggregateBuffer.setField(groupKeysMapping.length + i, agg.merge(accumulatorList(i))) + } - // Set group keys to aggregateBuffer. + // set group keys to aggregateBuffer. for (i <- groupKeysMapping.indices) { aggregateBuffer.setField(i, last.getField(i)) } http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala index c147629..23b5236 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala @@ -18,43 +18,48 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable +import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} import org.apache.flink.types.Row import org.apache.flink.util.{Collector, Preconditions} -import scala.collection.JavaConversions._ - /** - * It wraps the aggregate logic inside of - * [[org.apache.flink.api.java.operators.GroupReduceOperator]]. - * - * @param aggregates The aggregate functions. - * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping The index mapping between aggregate function list and aggregated value - * index in output Row. - * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate - * Row and output Row. - */ + * It wraps the aggregate logic inside of + * [[org.apache.flink.api.java.operators.GroupReduceOperator]]. + * + * @param aggregates The aggregate functions. + * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row + * and output Row. + * @param aggregateMapping The index mapping between aggregate function list and aggregated + * value + * index in output Row. + * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate + * Row and output Row. 
+ * @param finalRowArity The arity of the final resulting row + */ class AggregateReduceGroupFunction( - private val aggregates: Array[Aggregate[_ <: Any]], + private val aggregates: Array[AggregateFunction[_ <: Any]], private val groupKeysMapping: Array[(Int, Int)], private val aggregateMapping: Array[(Int, Int)], private val groupingSetsMapping: Array[(Int, Int)], - private val intermediateRowArity: Int, private val finalRowArity: Int) extends RichGroupReduceFunction[Row, Row] { protected var aggregateBuffer: Row = _ private var output: Row = _ private var intermediateGroupKeys: Option[Array[Int]] = None + protected val maxMergeLen = 16 + val accumulatorList = Array.fill(aggregates.length) { + new JArrayList[Accumulator]() + } override def open(config: Configuration) { Preconditions.checkNotNull(aggregates) Preconditions.checkNotNull(groupKeysMapping) - aggregateBuffer = new Row(intermediateRowArity) + aggregateBuffer = new Row(aggregates.length + groupKeysMapping.length) output = new Row(finalRowArity) if (!groupingSetsMapping.isEmpty) { intermediateGroupKeys = Some(groupKeysMapping.map(_._1)) @@ -62,25 +67,44 @@ class AggregateReduceGroupFunction( } /** - * For grouped intermediate aggregate Rows, merge all of them into aggregate buffer, - * calculate aggregated values output by aggregate buffer, and set them into output - * Row based on the mapping relation between intermediate aggregate data and output data. - * - * @param records Grouped intermediate aggregate Rows iterator. - * @param out The collector to hand results to. - * - */ + * For grouped intermediate aggregate Rows, merge all of them into aggregate buffer, + * calculate aggregated values output by aggregate buffer, and set them into output + * Row based on the mapping relation between intermediate aggregate data and output data. + * + * @param records Grouped intermediate aggregate Rows iterator. + * @param out The collector to hand results to. + * + */ override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { - // Initiate intermediate aggregate value. - aggregates.foreach(_.initiate(aggregateBuffer)) - - // Merge intermediate aggregate value to buffer. + // merge intermediate aggregate value to buffer. var last: Row = null - records.foreach((record) => { - aggregates.foreach(_.merge(record, aggregateBuffer)) + accumulatorList.foreach(_.clear()) + + val iterator = records.iterator() + + var count: Int = 0 + while (iterator.hasNext) { + val record = iterator.next() + count += 1 + // per each aggregator, collect its accumulators to a list + for (i <- aggregates.indices) { + accumulatorList(i).add(record.getField(groupKeysMapping.length + i) + .asInstanceOf[Accumulator]) + } + // if the number of buffered accumulators is bigger than maxMergeLen, merge them into one + // accumulator + if (count > maxMergeLen) { + count = 0 + for (i <- aggregates.indices) { + val agg = aggregates(i) + val accumulator = agg.merge(accumulatorList(i)) + accumulatorList(i).clear() + accumulatorList(i).add(accumulator) + } + } last = record - }) + } // Set group keys value to final output. groupKeysMapping.foreach { @@ -88,10 +112,14 @@ class AggregateReduceGroupFunction( output.setField(after, last.getField(previous)) } - // Evaluate final aggregate value and set to output. + // get final aggregate value and set to output. 
aggregateMapping.foreach { - case (after, previous) => - output.setField(after, aggregates(previous).evaluate(aggregateBuffer)) + case (after, previous) => { + val agg = aggregates(previous) + val accumulator = agg.merge(accumulatorList(previous)) + val result = agg.getValue(accumulator) + output.setField(after, result) + } } // Evaluate additional values of grouping sets http://git-wip-us.apache.org/repos/asf/flink/blob/438276de/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala deleted file mode 100644 index 8f96848..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateTimeWindowFunction.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.runtime.aggregate - -import java.lang.Iterable - -import org.apache.flink.api.common.functions.RichGroupReduceFunction -import org.apache.flink.api.java.tuple.Tuple -import org.apache.flink.types.Row -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.windowing.windows.TimeWindow -import org.apache.flink.util.Collector - -class AggregateTimeWindowFunction( - groupReduceFunction: RichGroupReduceFunction[Row, Row], - windowStartPos: Option[Int], - windowEndPos: Option[Int]) - extends AggregateWindowFunction[TimeWindow](groupReduceFunction) { - - private var collector: TimeWindowPropertyCollector = _ - - override def open(parameters: Configuration): Unit = { - collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) - super.open(parameters) - } - - override def apply( - key: Tuple, - window: TimeWindow, - input: Iterable[Row], - out: Collector[Row]): Unit = { - - // set collector and window - collector.wrappedCollector = out - collector.windowStart = window.getStart - collector.windowEnd = window.getEnd - - // call wrapped reduce function with property collector - super.apply(key, window, input, collector) - } -}
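Note: the refactored runtime functions above interact with user-defined aggregates only through the createAccumulator(), accumulate(), merge(), and getValue() calls visible in the hunks. The sketch below is a minimal, self-contained illustration of that contract using a hypothetical count aggregate; the two traits are simplified stand-ins for org.apache.flink.table.functions.Accumulator and AggregateFunction and do not reproduce the exact signatures introduced by this commit.

    import java.util.{List => JList}

    // Simplified stand-in for org.apache.flink.table.functions.Accumulator.
    trait Accumulator

    // Simplified stand-in for the UDAGG contract used by the runtime functions above:
    // create a per-group accumulator, add values to it, merge partial accumulators,
    // and finally evaluate the result.
    trait AggregateFunction[T] {
      def createAccumulator(): Accumulator
      def accumulate(accumulator: Accumulator, value: Any): Unit
      def merge(accumulators: JList[Accumulator]): Accumulator
      def getValue(accumulator: Accumulator): T
    }

    // Hypothetical accumulator holding a single running count.
    class CountAccumulator extends Accumulator {
      var count: Long = 0L
    }

    // Hypothetical count aggregate written against the four methods above.
    class CountAggFunction extends AggregateFunction[Long] {

      override def createAccumulator(): Accumulator = new CountAccumulator

      override def accumulate(accumulator: Accumulator, value: Any): Unit = {
        if (value != null) {
          accumulator.asInstanceOf[CountAccumulator].count += 1
        }
      }

      override def merge(accumulators: JList[Accumulator]): Accumulator = {
        // Fold all partial counts into a fresh accumulator.
        val merged = new CountAccumulator
        val it = accumulators.iterator()
        while (it.hasNext) {
          merged.count += it.next().asInstanceOf[CountAccumulator].count
        }
        merged
      }

      override def getValue(accumulator: Accumulator): Long =
        accumulator.asInstanceOf[CountAccumulator].count
    }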

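The new AggregateReduceGroupFunction and AggregateReduceCombineFunction buffer one accumulator per pre-aggregated row and collapse the buffer with merge() once it grows beyond maxMergeLen, so at most maxMergeLen + 1 accumulators are held at a time. The driver below is a hypothetical, Flink-free illustration of that batching pattern, reusing the stand-in CountAggFunction from the previous sketch; only the maxMergeLen name mirrors the field in the diff, everything else is illustrative.

    import java.util.{ArrayList => JArrayList}

    object BoundedMergeExample {

      def main(args: Array[String]): Unit = {
        val agg = new CountAggFunction
        // Mirrors the maxMergeLen field in AggregateReduceGroupFunction: the buffer of
        // partial accumulators is collapsed once it exceeds this size.
        val maxMergeLen = 16
        val accumulatorList = new JArrayList[Accumulator]()

        var count = 0
        for (value <- 1 to 100) {
          // One accumulator per incoming record, as produced by the map phase.
          val acc = agg.createAccumulator()
          agg.accumulate(acc, value)
          accumulatorList.add(acc)
          count += 1

          // Collapse the buffered accumulators into a single one once the cap is hit.
          if (count > maxMergeLen) {
            count = 0
            val merged = agg.merge(accumulatorList)
            accumulatorList.clear()
            accumulatorList.add(merged)
          }
        }

        // Final value: merge whatever remains in the buffer and evaluate it.
        val result = agg.getValue(agg.merge(accumulatorList))
        println(s"count = $result") // prints: count = 100
      }
    }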