Repository: flink Updated Branches: refs/heads/master 3fcc4e37c -> 7456d78d2
[FLINK-5804] [table] Add support for procTime non-partitioned OVER RANGE BETWEEN UNBOUNDED PRECEDING aggregation to SQL. This closes #3491. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7456d78d Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7456d78d Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7456d78d Branch: refs/heads/master Commit: 7456d78d271b217c80d46e24029c55741807e51d Parents: 3fcc4e3 Author: 金竹 <jincheng.su...@alibaba-inc.com> Authored: Wed Mar 8 10:52:43 2017 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Thu Mar 9 19:11:43 2017 +0100 ---------------------------------------------------------------------- .../datastream/DataStreamOverAggregate.scala | 14 ++- .../table/runtime/aggregate/AggregateUtil.scala | 27 +++-- ...rtitionedProcessingOverProcessFunction.scala | 106 +++++++++++++++++++ .../table/api/scala/stream/sql/SqlITCase.scala | 53 ++++++++++ .../scala/stream/sql/WindowAggregateTest.scala | 54 ++++++++++ 5 files changed, 243 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala index db115e0..34b3b0f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala @@ -143,10 +143,18 @@ class DataStreamOverAggregate( 
.name(aggOpName) .asInstanceOf[DataStream[Row]] } - // global non-partitioned aggregation + // non-partitioned aggregation else { - throw TableException( - "Non-partitioned processing time OVER aggregation is not supported yet.") + val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction( + namedAggregates, + inputType, + false) + + inputDS + .process(processFunction).setParallelism(1).setMaxParallelism(1) + .returns(rowTypeInfo) + .name(aggOpName) + .asInstanceOf[DataStream[Row]] } result } http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 6555143..b6b3445 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -53,15 +53,18 @@ object AggregateUtil { type JavaList[T] = java.util.List[T] /** - * Create an [[ProcessFunction]] to evaluate final aggregate value. + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] to evaluate final + * aggregate value. 
* * @param namedAggregates List of calls to aggregate functions and their output field names * @param inputType Input row type - * @return [[UnboundedProcessingOverProcessFunction]] + * @param isPartitioned Flag to indicate whether the input is partitioned or not + * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] */ private[flink] def CreateUnboundedProcessingOverProcessFunction( namedAggregates: Seq[CalcitePair[AggregateCall, String]], - inputType: RelDataType): UnboundedProcessingOverProcessFunction = { + inputType: RelDataType, + isPartitioned: Boolean = true): ProcessFunction[Row, Row] = { val (aggFields, aggregates) = transformToAggregateFunctions( @@ -72,11 +75,19 @@ object AggregateUtil { val aggregationStateType: RowTypeInfo = createDataSetAggregateBufferDataType(Array(), aggregates, inputType) - new UnboundedProcessingOverProcessFunction( - aggregates, - aggFields, - inputType.getFieldCount, - aggregationStateType) + if (isPartitioned) { + new UnboundedProcessingOverProcessFunction( + aggregates, + aggFields, + inputType.getFieldCount, + aggregationStateType) + } else { + new UnboundedNonPartitionedProcessingOverProcessFunction( + aggregates, + aggFields, + inputType.getFieldCount, + aggregationStateType) + } } /** http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala new file mode 100644 index 0000000..51c8315 --- /dev/null +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import org.apache.flink.api.common.state.{ListState, ListStateDescriptor} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.configuration.Configuration +import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext} +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.types.Row +import org.apache.flink.util.{Collector, Preconditions} + +/** + * Process Function used for the aggregate in + * [[org.apache.flink.streaming.api.datastream.DataStream]] + * + * @param aggregates the list of all [[org.apache.flink.table.functions.AggregateFunction]] + * used for this aggregation + * @param aggFields the position (in the input Row) of the input value for each aggregate + */ +class 
UnboundedNonPartitionedProcessingOverProcessFunction( + private val aggregates: Array[AggregateFunction[_]], + private val aggFields: Array[Int], + private val forwardedFieldCount: Int, + private val aggregationStateType: RowTypeInfo) + extends ProcessFunction[Row, Row] with CheckpointedFunction{ + + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(aggFields) + Preconditions.checkArgument(aggregates.length == aggFields.length) + + private var accumulators: Row = _ + private var output: Row = _ + private var state: ListState[Row] = null + + override def open(config: Configuration) { + output = new Row(forwardedFieldCount + aggregates.length) + if (null == accumulators) { + val it = state.get().iterator() + if (it.hasNext) { + accumulators = it.next() + } else { + accumulators = new Row(aggregates.length) + var i = 0 + while (i < aggregates.length) { + accumulators.setField(i, aggregates(i).createAccumulator()) + i += 1 + } + } + } + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + var i = 0 + while (i < forwardedFieldCount) { + output.setField(i, input.getField(i)) + i += 1 + } + + i = 0 + while (i < aggregates.length) { + val index = forwardedFieldCount + i + val accumulator = accumulators.getField(i).asInstanceOf[Accumulator] + aggregates(i).accumulate(accumulator, input.getField(aggFields(i))) + output.setField(index, aggregates(i).getValue(accumulator)) + i += 1 + } + + out.collect(output) + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + state.clear() + if (null != accumulators) { + state.add(accumulators) + } + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + val stateSerializer = + aggregationStateType.createSerializer(getRuntimeContext.getExecutionConfig) + val accumulatorsDescriptor = new ListStateDescriptor[Row]("overState", stateSerializer) + state = 
context.getOperatorStateStore.getOperatorState(accumulatorsDescriptor) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala index cf8e442..d5a140a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala @@ -240,6 +240,59 @@ class SqlITCase extends StreamingWithStateTestBase { assertEquals(expected.sorted, StreamITCase.testResults.sorted) } + @Test + def testUnboundNonPartitionedProcessingWindowWithRange(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + // for sum aggregation ensure that every time the order of each element is consistent + env.setParallelism(1) + + val t1 = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c) + + tEnv.registerTable("T1", t1) + + val sqlQuery = "SELECT " + + "c, " + + "count(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt1, " + + "sum(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt2 " + + "from T1" + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "Hello World,7,28", "Hello World,8,36", "Hello World,9,56", + "Hello,1,1", "Hello,2,3", "Hello,3,6", "Hello,4,10", "Hello,5,15", "Hello,6,21") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + 
@Test + def testUnboundNonPartitionedProcessingWindowWithRow(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val t1 = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c) + + tEnv.registerTable("T1", t1) + + val sqlQuery = "SELECT " + + "count(a) OVER (ORDER BY ProcTime() ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW)" + + "from T1" + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("1", "2", "3", "4", "5", "6", "7", "8", "9") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + /** * All aggregates must be computed on the same window. */ http://git-wip-us.apache.org/repos/asf/flink/blob/7456d78d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala index 85bc5a7..2781fb8 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala @@ -185,4 +185,58 @@ class WindowAggregateTest extends TableTestBase { ) streamUtil.verifySql(sql, expected) } + + @Test + def testUnboundNonPartitionedProcessingWindowWithRange() = { + val sql = "SELECT " + + "c, " + + "count(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt1, " + + "sum(a) OVER (ORDER BY ProcTime() RANGE UNBOUNDED preceding) as cnt2 " + + "from MyTable" + + val 
expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamOverAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c", "PROCTIME() AS $2") + ), + term("orderBy", "PROCTIME"), + term("range", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"), + term("select", "a", "c", "PROCTIME", "COUNT(a) AS w0$o0", "$SUM0(a) AS w0$o1") + ), + term("select", "c", "w0$o0 AS cnt1", "CASE(>(w0$o0, 0)", "CAST(w0$o1), null) AS cnt2") + ) + streamUtil.verifySql(sql, expected) + } + + @Test + def testUnboundNonPartitionedProcessingWindowWithRow() = { + val sql = "SELECT " + + "c, " + + "count(a) OVER (ORDER BY ProcTime() ROWS BETWEEN UNBOUNDED preceding AND " + + "CURRENT ROW) as cnt1 " + + "from MyTable" + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamOverAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c", "PROCTIME() AS $2") + ), + term("orderBy", "PROCTIME"), + term("rows", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"), + term("select", "a", "c", "PROCTIME", "COUNT(a) AS w0$o0") + ), + term("select", "c", "w0$o0 AS $1") + ) + streamUtil.verifySql(sql, expected) + } }