Repository: flink Updated Branches: refs/heads/master 404e37d21 -> 861c57cb1
[FLINK-7371] [table] Add support for constant parameters in OVER aggregate This closes #4736. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/861c57cb Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/861c57cb Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/861c57cb Branch: refs/heads/master Commit: 861c57cb10a3ccc862ec068869582bc1551feaa5 Parents: 404e37d Author: twalthr <[email protected]> Authored: Wed Sep 27 17:11:28 2017 +0200 Committer: twalthr <[email protected]> Committed: Mon Oct 16 17:57:16 2017 +0200 ---------------------------------------------------------------------- .../codegen/AggregationCodeGenerator.scala | 68 ++++++++++++++------ .../flink/table/codegen/CodeGenerator.scala | 50 +++++++++++--- .../flink/table/plan/nodes/OverAggregate.scala | 24 +++++-- .../plan/nodes/dataset/DataSetAggregate.scala | 3 +- .../nodes/dataset/DataSetWindowAggregate.scala | 3 +- .../plan/nodes/datastream/DataStreamCalc.scala | 2 +- .../datastream/DataStreamGroupAggregate.scala | 3 +- .../DataStreamGroupWindowAggregate.scala | 3 +- .../datastream/DataStreamOverAggregate.scala | 14 +++- .../utils/JavaUserDefinedAggFunctions.java | 7 ++ .../runtime/stream/table/OverWindowITCase.scala | 39 +++++++++++ 11 files changed, 173 insertions(+), 43 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala index 22ce5ba..82a2420 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.codegen import java.lang.reflect.{Modifier, ParameterizedType} import java.lang.{Iterable => JIterable} +import org.apache.calcite.rex.RexLiteral import org.apache.commons.codec.binary.Base64 import org.apache.flink.api.common.state.{State, StateDescriptor} import org.apache.flink.api.common.typeinfo.TypeInformation @@ -41,11 +42,13 @@ import scala.collection.mutable * @param config configuration that determines runtime behavior * @param nullableInput input(s) can be null. * @param input type information about the input of the Function + * @param constants constant expressions that act like a second input in the parameter indices. */ class AggregationCodeGenerator( config: TableConfig, nullableInput: Boolean, - input: TypeInformation[_ <: Any]) + input: TypeInformation[_ <: Any], + constants: Option[Seq[RexLiteral]]) extends CodeGenerator(config, nullableInput, input) { // set of statements for cleanup dataview that will be added only once @@ -81,25 +84,26 @@ class AggregationCodeGenerator( * @param needRetract a flag to indicate if the aggregate needs the retract method * @param needMerge a flag to indicate if the aggregate needs the merge method * @param needReset a flag to indicate if the aggregate needs the resetAccumulator method + * @param accConfig Data view specification for accumulators * * @return A GeneratedAggregationsFunction */ def generateAggregations( - name: String, - physicalInputTypes: Seq[TypeInformation[_]], - aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]], - aggFields: Array[Array[Int]], - aggMapping: Array[Int], - partialResults: Boolean, - fwdMapping: Array[Int], - mergeMapping: Option[Array[Int]], - constantFlags: Option[Array[(Int, Boolean)]], - outputArity: Int, - needRetract: Boolean, - needMerge: Boolean, - needReset: Boolean, - accConfig: Option[Array[Seq[DataViewSpec[_]]]]) - : GeneratedAggregationsFunction = { + name: String, + physicalInputTypes: Seq[TypeInformation[_]], + aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]], + aggFields: Array[Array[Int]], + aggMapping: Array[Int], + partialResults: Boolean, + fwdMapping: Array[Int], + mergeMapping: Option[Array[Int]], + constantFlags: Option[Array[(Int, Boolean)]], + outputArity: Int, + needRetract: Boolean, + needMerge: Boolean, + needReset: Boolean, + accConfig: Option[Array[Seq[DataViewSpec[_]]]]) + : GeneratedAggregationsFunction = { // get unique function name val funcName = newName(name) @@ -112,17 +116,41 @@ class AggregationCodeGenerator( } val accTypes = accTypeClasses.map(_.getCanonicalName) + // create constants + val constantExprs = constants.map(_.map(generateExpression)).getOrElse(Seq()) + val constantTypes = constantExprs.map(_.resultType) + val constantFields = constantExprs.map(addReusableBoxedConstant) + // get parameter lists for aggregation functions val parametersCode = aggFields.map { inFields => - val fields = for (f <- inFields) yield - s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) input.getField($f)" + val fields = inFields.map { f => + // index to constant + if (f >= physicalInputTypes.length) { + constantFields(f - physicalInputTypes.length) + } + // index to input field + else { + s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) input.getField($f)" + } + } + fields.mkString(", ") } // get method signatures val classes = UserDefinedFunctionUtils.typeInfoToClass(physicalInputTypes) + val constantClasses = UserDefinedFunctionUtils.typeInfoToClass(constantTypes) val methodSignaturesList = aggFields.map { inFields => - inFields.map(classes(_)) + inFields.map { f => + // index to constant + if (f >= physicalInputTypes.length) { + constantClasses(f - physicalInputTypes.length) + } + // index to input field + else { + classes(f) + } + } } // initialize and create data views @@ -219,7 +247,7 @@ class AggregationCodeGenerator( val dataViewFieldTerm = createDataViewTerm(i, dataViewField.getName) val field = s""" - | transient $dataViewTypeTerm $dataViewFieldTerm = null; + | final $dataViewTypeTerm $dataViewFieldTerm; |""".stripMargin reusableMemberStatements.add(field) http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 3fead21..5e7ec32 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -1347,12 +1347,12 @@ abstract class CodeGenerator( val statement = ti match { case rt: RowTypeInfo => s""" - |transient ${ti.getTypeClass.getCanonicalName} $outRecordTerm = + |final ${ti.getTypeClass.getCanonicalName} $outRecordTerm = | new ${ti.getTypeClass.getCanonicalName}(${rt.getArity}); |""".stripMargin case _ => s""" - |${ti.getTypeClass.getCanonicalName} $outRecordTerm = + |final ${ti.getTypeClass.getCanonicalName} $outRecordTerm = | new ${ti.getTypeClass.getCanonicalName}(); |""".stripMargin } @@ -1372,7 +1372,7 @@ abstract class CodeGenerator( val fieldTerm = s"field_${clazz.getCanonicalName.replace('.', '$')}_$fieldName" val fieldExtraction = s""" - |transient java.lang.reflect.Field $fieldTerm = + |final java.lang.reflect.Field $fieldTerm = | org.apache.flink.api.java.typeutils.TypeExtractor.getDeclaredField( | ${clazz.getCanonicalName}.class, "$fieldName"); |""".stripMargin @@ -1401,7 +1401,7 @@ abstract class CodeGenerator( val fieldTerm = newName("decimal") val fieldDecimal = s""" - |transient java.math.BigDecimal $fieldTerm = + |final java.math.BigDecimal $fieldTerm = | new java.math.BigDecimal("${decimal.toString}"); |""".stripMargin reusableMemberStatements.add(fieldDecimal) @@ -1420,7 +1420,7 @@ abstract class CodeGenerator( val field = s""" - |transient java.util.Random $fieldTerm; + |final java.util.Random $fieldTerm; |""".stripMargin reusableMemberStatements.add(field) @@ -1460,7 +1460,7 @@ abstract class CodeGenerator( val field = s""" - |transient org.joda.time.format.DateTimeFormatter $fieldTerm; + |final org.joda.time.format.DateTimeFormatter $fieldTerm; |""".stripMargin reusableMemberStatements.add(field) @@ -1489,7 +1489,7 @@ abstract class CodeGenerator( val fieldFunction = s""" - |transient $classQualifier $fieldTerm = null; + |final $classQualifier $fieldTerm; |""".stripMargin reusableMemberStatements.add(fieldFunction) @@ -1536,7 +1536,7 @@ abstract class CodeGenerator( parameterTypes.zipWithIndex.foreach { case (t, index) => val classQualifier = t.getCanonicalName val fieldTerm = newName(s"instance_${classQualifier.replace('.', '$')}") - val field = s"transient $classQualifier $fieldTerm = null;" + val field = s"final $classQualifier $fieldTerm;" reusableMemberStatements.add(field) fieldTerms += fieldTerm parameters += s"$classQualifier arg$index" @@ -1557,7 +1557,7 @@ abstract class CodeGenerator( val initArray = classQualifier.replaceFirst("\\[", s"[$size") val fieldArray = s""" - |transient $classQualifier $fieldTerm = + |final $classQualifier $fieldTerm = | new $initArray; |""".stripMargin reusableMemberStatements.add(fieldArray) @@ -1664,7 +1664,7 @@ abstract class CodeGenerator( val field = s""" - |transient java.util.Set $fieldTerm = null; + |final java.util.Set $fieldTerm; |""".stripMargin reusableMemberStatements.add(field) @@ -1690,4 +1690,34 @@ abstract class CodeGenerator( fieldTerm } + + /** + * Adds a reusable constant to the member area of the generated [[Function]]. + * + * @param constant constant expression + * @return member variable term + */ + def addReusableBoxedConstant(constant: GeneratedExpression): String = { + require(constant.literal, "Literal expected") + + val fieldTerm = newName("constant") + + val boxed = generateOutputFieldBoxing(constant) + val boxedType = boxedTypeTermForTypeInfo(boxed.resultType) + + val field = + s""" + |final $boxedType $fieldTerm; + |""".stripMargin + reusableMemberStatements.add(field) + + val init = + s""" + |${boxed.code} + |$fieldTerm = ${boxed.resultTerm}; + |""".stripMargin + reusableInitStatements.add(init) + + fieldTerm + } } http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala index 1048549..87ebd86 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala @@ -19,11 +19,10 @@ package org.apache.flink.table.plan.nodes import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.{AggregateCall, Window} import org.apache.calcite.rel.core.Window.Group +import org.apache.calcite.rel.core.{AggregateCall, Window} import org.apache.calcite.rel.{RelFieldCollation, RelNode} -import org.apache.calcite.rex.RexInputRef -import org.apache.flink.table.plan.schema.RowSchema +import org.apache.calcite.rex.{RexInputRef, RexLiteral} import org.apache.flink.table.runtime.aggregate.AggregateUtil._ import scala.collection.JavaConverters._ @@ -61,9 +60,11 @@ trait OverAggregate { } private[flink] def aggregationToString( - inputType: RelDataType, - rowType: RelDataType, - namedAggregates: Seq[CalcitePair[AggregateCall, String]]): String = { + inputType: RelDataType, + constants: Seq[RexLiteral], + rowType: RelDataType, + namedAggregates: Seq[CalcitePair[AggregateCall, String]]) + : String = { val inFields = inputType.getFieldNames.asScala val outFields = rowType.getFieldNames.asScala @@ -71,7 +72,16 @@ trait OverAggregate { val aggStrings = namedAggregates.map(_.getKey).map( a => s"${a.getAggregation}(${ if (a.getArgList.size() > 0) { - a.getArgList.asScala.map(inFields(_)).mkString(", ") + a.getArgList.asScala.map { arg => + // index to constant + if (arg >= inputType.getFieldCount) { + constants(arg - inputType.getFieldCount) + } + // index to input field + else { + inFields(arg) + } + }.mkString(", ") } else { "*" } http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala index 37d1a51..bbc5746 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala @@ -98,7 +98,8 @@ class DataSetAggregate( val generator = new AggregationCodeGenerator( tableEnv.getConfig, false, - inputDS.getType) + inputDS.getType, + None) val ( preAgg: Option[DataSetPreAggFunction], http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala index 38de368..66dcc56 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala @@ -112,7 +112,8 @@ class DataSetWindowAggregate( val generator = new AggregationCodeGenerator( tableEnv.getConfig, false, - inputDS.getType) + inputDS.getType, + None) // whether identifiers are matched case-sensitively val caseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive() http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala index 45e6902..05a8f5c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala @@ -27,8 +27,8 @@ import org.apache.calcite.rex.RexProgram import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment} +import org.apache.flink.table.calcite.RelTimeIndicatorConverter import org.apache.flink.table.codegen.FunctionCodeGenerator -import org.apache.flink.table.calcite.{FlinkTypeFactory, RelTimeIndicatorConverter} import org.apache.flink.table.plan.nodes.CommonCalc import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.CRowProcessRunner http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala index 58c9d82..742a7e4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala @@ -116,7 +116,8 @@ class DataStreamGroupAggregate( val generator = new AggregationCodeGenerator( tableEnv.getConfig, false, - inputSchema.typeInfo) + inputSchema.typeInfo, + None) val aggString = aggregationToString( inputSchema.relDataType, http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala index b15350f..db15839 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala @@ -173,7 +173,8 @@ class DataStreamGroupWindowAggregate( val generator = new AggregationCodeGenerator( tableEnv.getConfig, false, - inputSchema.typeInfo) + inputSchema.typeInfo, + None) val needMerge = window match { case SessionGroupWindow(_, _, _) => true http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala index 6234525..b9b3e3e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala @@ -25,6 +25,7 @@ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.Window.Group import org.apache.calcite.rel.core.{AggregateCall, Window} import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.calcite.rex.RexLiteral import org.apache.flink.api.java.functions.NullByteKeySelector import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException} @@ -38,6 +39,8 @@ import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import org.apache.flink.table.util.Logging +import scala.collection.JavaConverters._ + class DataStreamOverAggregate( logicWindow: Window, cluster: RelOptCluster, @@ -73,6 +76,7 @@ class DataStreamOverAggregate( override def explainTerms(pw: RelWriter): RelWriter = { val overWindow: Group = logicWindow.groups.get(0) + val constants: Seq[RexLiteral] = logicWindow.constants.asScala val partitionKeys: Array[Int] = overWindow.keys.toArray val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates @@ -86,6 +90,7 @@ class DataStreamOverAggregate( .item( "select", aggregationToString( inputSchema.relDataType, + constants, schema.relDataType, namedAggregates)) } @@ -131,10 +136,13 @@ class DataStreamOverAggregate( "excessive state size. You may specify a retention time of 0 to not clean up the state.") } + val constants: Seq[RexLiteral] = logicWindow.constants.asScala + val generator = new AggregationCodeGenerator( tableEnv.getConfig, false, - inputSchema.typeInfo) + inputSchema.typeInfo, + Some(constants)) val timeType = schema.relDataType .getFieldList @@ -230,7 +238,9 @@ class DataStreamOverAggregate( isRowsClause: Boolean): DataStream[CRow] = { val overWindow: Group = logicWindow.groups.get(0) + val partitionKeys: Array[Int] = overWindow.keys.toArray + val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates val precedingOffset = @@ -280,6 +290,7 @@ class DataStreamOverAggregate( private def aggOpName = { val overWindow: Group = logicWindow.groups.get(0) + val constants: Seq[RexLiteral] = logicWindow.constants.asScala val partitionKeys: Array[Int] = overWindow.keys.toArray val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates @@ -296,6 +307,7 @@ class DataStreamOverAggregate( s"select: (${ aggregationToString( inputSchema.relDataType, + constants, schema.relDataType, namedAggregates) }))" http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java index 61f43dc..abf2c49 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java @@ -86,6 +86,13 @@ public class JavaUserDefinedAggFunctions { } // overloaded accumulate method + // dummy to test constants + public void accumulate(WeightedAvgAccum accumulator, long iValue, int iWeight, int x, String string) { + accumulator.sum += (iValue + Integer.parseInt(string)) * iWeight; + accumulator.count += iWeight; + } + + // overloaded accumulate method public void accumulate(WeightedAvgAccum accumulator, long iValue, int iWeight) { accumulator.sum += iValue * iWeight; accumulator.count += iWeight; http://git-wip-us.apache.org/repos/asf/flink/blob/861c57cb/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala index 54971b2..3e6b0c6 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala @@ -85,6 +85,45 @@ class OverWindowITCase extends StreamingWithStateTestBase { } @Test + def testOverWindowWithConstant(): Unit = { + + val data = List( + (1L, 1, "Hello"), + (2L, 2, "Hello"), + (3L, 3, "Hello"), + (4L, 4, "Hello"), + (5L, 5, "Hello"), + (6L, 6, "Hello"), + (7L, 7, "Hello World"), + (8L, 8, "Hello World"), + (8L, 8, "Hello World"), + (20L, 20, "Hello World")) + + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + StreamITCase.clear + val stream = env.fromCollection(data) + val table = stream.toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime) + val weightAvgFun = new WeightedAvg + + val windowedTable = table + .window( + Over partitionBy 'c orderBy 'proctime preceding UNBOUNDED_ROW as 'w) + .select('c, weightAvgFun('a, 42, 'b, "2") over 'w as 'wAvg) + + val results = windowedTable.toAppendStream[Row] + results.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = Seq( + "Hello World,12", "Hello World,9", "Hello World,9", "Hello World,9", "Hello,3", + "Hello,3", "Hello,4", "Hello,4", "Hello,5", "Hello,5") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test def testRowTimeUnBoundedPartitionedRangeOver(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env)
