[FLINK-6216] [table] Add non-windowed GroupBy aggregation for streams. This closes #3646.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/24fa1a1c Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/24fa1a1c Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/24fa1a1c Branch: refs/heads/table-retraction Commit: 24fa1a1c2516726a5ecd446bc79b9fb6664bec38 Parents: 6181302 Author: shaoxuan-wang <[email protected]> Authored: Thu Mar 30 03:57:58 2017 +0800 Committer: Fabian Hueske <[email protected]> Committed: Wed May 3 11:27:12 2017 +0200 ---------------------------------------------------------------------- .../flink/table/plan/logical/operators.scala | 3 - .../nodes/datastream/DataStreamAggregate.scala | 272 ------------------- .../datastream/DataStreamGroupAggregate.scala | 140 ++++++++++ .../DataStreamGroupWindowAggregate.scala | 272 +++++++++++++++++++ .../flink/table/plan/rules/FlinkRuleSets.scala | 3 +- .../datastream/DataStreamAggregateRule.scala | 76 ------ .../DataStreamGroupAggregateRule.scala | 79 ++++++ .../DataStreamGroupWindowAggregateRule.scala | 76 ++++++ .../table/runtime/aggregate/AggregateUtil.scala | 63 ++++- .../aggregate/GroupAggProcessFunction.scala | 90 ++++++ .../aggregate/ProcTimeBoundedRangeOver.scala | 2 +- .../scala/batch/table/FieldProjectionTest.scala | 4 +- .../table/api/scala/stream/sql/SqlITCase.scala | 21 ++ .../scala/stream/sql/WindowAggregateTest.scala | 46 ++-- .../scala/stream/table/AggregationsITCase.scala | 180 ------------ .../stream/table/GroupAggregationsITCase.scala | 132 +++++++++ .../stream/table/GroupAggregationsTest.scala | 214 +++++++++++++++ .../table/GroupWindowAggregationsITCase.scala | 180 ++++++++++++ .../scala/stream/table/GroupWindowTest.scala | 56 ++-- .../scala/stream/table/UnsupportedOpsTest.scala | 7 - 20 files changed, 1318 insertions(+), 598 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index f1bb644..66b26ed 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -221,9 +221,6 @@ case class Aggregate( } override def validate(tableEnv: TableEnvironment): LogicalNode = { - if (tableEnv.isInstanceOf[StreamTableEnvironment]) { - failValidation(s"Aggregate on stream tables is currently not supported.") - } val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate] val groupingExprs = resolvedAggregate.groupingExpressions http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala deleted file mode 100644 index 187773d..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.nodes.datastream - -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} -import org.apache.flink.api.java.tuple.Tuple -import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream} -import org.apache.flink.streaming.api.windowing.assigners._ -import org.apache.flink.streaming.api.windowing.time.Time -import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} -import org.apache.flink.table.api.StreamTableEnvironment -import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.codegen.CodeGenerator -import org.apache.flink.table.expressions._ -import org.apache.flink.table.plan.logical._ -import org.apache.flink.table.plan.nodes.CommonAggregate -import org.apache.flink.table.plan.nodes.datastream.DataStreamAggregate._ -import org.apache.flink.table.runtime.aggregate.AggregateUtil._ -import org.apache.flink.table.runtime.aggregate._ -import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval -import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} -import 
org.apache.flink.types.Row - -class DataStreamAggregate( - window: LogicalWindow, - namedProperties: Seq[NamedWindowProperty], - cluster: RelOptCluster, - traitSet: RelTraitSet, - inputNode: RelNode, - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - rowRelDataType: RelDataType, - inputType: RelDataType, - grouping: Array[Int]) - extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { - - override def deriveRowType(): RelDataType = rowRelDataType - - override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { - new DataStreamAggregate( - window, - namedProperties, - cluster, - traitSet, - inputs.get(0), - namedAggregates, - getRowType, - inputType, - grouping) - } - - override def toString: String = { - s"Aggregate(${ - if (!grouping.isEmpty) { - s"groupBy: (${groupingToString(inputType, grouping)}), " - } else { - "" - } - }window: ($window), " + - s"select: (${ - aggregationToString( - inputType, - grouping, - getRowType, - namedAggregates, - namedProperties) - }))" - } - - override def explainTerms(pw: RelWriter): RelWriter = { - super.explainTerms(pw) - .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty) - .item("window", window) - .item( - "select", aggregationToString( - inputType, - grouping, - getRowType, - namedAggregates, - namedProperties)) - } - - override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { - - val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) - - val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) - - val aggString = aggregationToString( - inputType, - grouping, - getRowType, - namedAggregates, - namedProperties) - - val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + - s"window: ($window), " + - s"select: ($aggString)" - val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" - - val generator = new CodeGenerator( - tableEnv.getConfig, - 
false, - inputDS.getType) - - // grouped / keyed aggregation - if (grouping.length > 0) { - val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( - window, - grouping.length, - namedAggregates.size, - rowRelDataType.getFieldCount, - namedProperties) - - val keyedStream = inputDS.keyBy(grouping: _*) - val windowedStream = - createKeyedWindowedStream(window, keyedStream) - .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] - - val (aggFunction, accumulatorRowType, aggResultRowType) = - AggregateUtil.createDataStreamAggregateFunction( - generator, - namedAggregates, - inputType, - rowRelDataType) - - windowedStream - .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) - .name(keyedAggOpName) - } - // global / non-keyed aggregation - else { - val windowFunction = AggregateUtil.createAggregationAllWindowFunction( - window, - rowRelDataType.getFieldCount, - namedProperties) - - val windowedStream = - createNonKeyedWindowedStream(window, inputDS) - .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] - - val (aggFunction, accumulatorRowType, aggResultRowType) = - AggregateUtil.createDataStreamAggregateFunction( - generator, - namedAggregates, - inputType, - rowRelDataType) - - windowedStream - .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) - .name(nonKeyedAggOpName) - } - } -} - -object DataStreamAggregate { - - - private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) - : WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { - - case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => - stream.window(TumblingProcessingTimeWindows.of(asTime(size))) - - case ProcessingTimeTumblingGroupWindow(_, size) => - stream.countWindow(asCount(size)) - - case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => - 
stream.window(TumblingEventTimeWindows.of(asTime(size))) - - case EventTimeTumblingGroupWindow(_, _, size) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => - stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) - - case ProcessingTimeSlidingGroupWindow(_, size, slide) => - stream.countWindow(asCount(size), asCount(slide)) - - case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => - stream.window(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) - - case EventTimeSlidingGroupWindow(_, _, size, slide) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. 
Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case ProcessingTimeSessionGroupWindow(_, gap: Expression) => - stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap))) - - case EventTimeSessionGroupWindow(_, _, gap) => - stream.window(EventTimeSessionWindows.withGap(asTime(gap))) - } - - private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row]) - : AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match { - - case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => - stream.windowAll(TumblingProcessingTimeWindows.of(asTime(size))) - - case ProcessingTimeTumblingGroupWindow(_, size) => - stream.countWindowAll(asCount(size)) - - case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => - stream.windowAll(TumblingEventTimeWindows.of(asTime(size))) - - case EventTimeTumblingGroupWindow(_, _, size) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. 
Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => - stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) - - case ProcessingTimeSlidingGroupWindow(_, size, slide) => - stream.countWindowAll(asCount(size), asCount(slide)) - - case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => - stream.windowAll(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) - - case EventTimeSlidingGroupWindow(_, _, size, slide) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case ProcessingTimeSessionGroupWindow(_, gap) => - stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap))) - - case EventTimeSessionGroupWindow(_, _, gap) => - stream.windowAll(EventTimeSessionWindows.withGap(asTime(gap))) - } - - def asTime(expr: Expression): Time = expr match { - case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => Time.milliseconds(value) - case _ => throw new IllegalArgumentException() - } - - def asCount(expr: Expression): Long = expr match { - case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value - case _ => throw new IllegalArgumentException() - } -} - http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala new file mode 100644 index 0000000..955d702 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.plan.nodes.datastream + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.flink.api.java.functions.NullByteKeySelector +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.CodeGenerator +import org.apache.flink.table.runtime.aggregate._ +import org.apache.flink.table.plan.nodes.CommonAggregate +import org.apache.flink.types.Row +import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair + +/** + * + * Flink RelNode for data stream unbounded group aggregate + * + * @param cluster Cluster of the RelNode, represent for an environment of related + * relational expressions during the optimization of a query. 
+ * @param traitSet Trait set of the RelNode + * @param inputNode The input RelNode of aggregation + * @param namedAggregates List of calls to aggregate functions and their output field names + * @param rowRelDataType The type of the rows of the RelNode + * @param inputType The type of the rows of aggregation input RelNode + * @param groupings The position (in the input Row) of the grouping keys + */ +class DataStreamGroupAggregate( + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputNode: RelNode, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + rowRelDataType: RelDataType, + inputType: RelDataType, + groupings: Array[Int]) + extends SingleRel(cluster, traitSet, inputNode) + with CommonAggregate + with DataStreamRel { + + override def deriveRowType() = rowRelDataType + + override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { + new DataStreamGroupAggregate( + cluster, + traitSet, + inputs.get(0), + namedAggregates, + getRowType, + inputType, + groupings) + } + + override def toString: String = { + s"Aggregate(${ + if (!groupings.isEmpty) { + s"groupBy: (${groupingToString(inputType, groupings)}), " + } else { + "" + } + }select:(${aggregationToString(inputType, groupings, getRowType, namedAggregates, Nil)}))" + } + + override def explainTerms(pw: RelWriter): RelWriter = { + super.explainTerms(pw) + .itemIf("groupBy", groupingToString(inputType, groupings), !groupings.isEmpty) + .item("select", aggregationToString(inputType, groupings, getRowType, namedAggregates, Nil)) + } + + override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { + + val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) + + val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) + + val generator = new CodeGenerator( + tableEnv.getConfig, + false, + inputDS.getType) + + val aggString = aggregationToString( + inputType, + groupings, + getRowType, + namedAggregates, + Nil) + + val 
keyedAggOpName = s"groupBy: (${groupingToString(inputType, groupings)}), " + + s"select: ($aggString)" + val nonKeyedAggOpName = s"select: ($aggString)" + + val processFunction = AggregateUtil.createGroupAggregateFunction( + generator, + namedAggregates, + inputType, + groupings) + + val result: DataStream[Row] = + // grouped / keyed aggregation + if (groupings.nonEmpty) { + inputDS + .keyBy(groupings: _*) + .process(processFunction) + .returns(rowTypeInfo) + .name(keyedAggOpName) + .asInstanceOf[DataStream[Row]] + } + // global / non-keyed aggregation + else { + inputDS + .keyBy(new NullByteKeySelector[Row]) + .process(processFunction) + .setParallelism(1) + .setMaxParallelism(1) + .returns(rowTypeInfo) + .name(nonKeyedAggOpName) + .asInstanceOf[DataStream[Row]] + } + result + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala new file mode 100644 index 0000000..752dbbe --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.nodes.datastream + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.flink.api.java.tuple.Tuple +import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream} +import org.apache.flink.streaming.api.windowing.assigners._ +import org.apache.flink.streaming.api.windowing.time.Time +import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.CodeGenerator +import org.apache.flink.table.expressions._ +import org.apache.flink.table.plan.logical._ +import org.apache.flink.table.plan.nodes.CommonAggregate +import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupWindowAggregate._ +import org.apache.flink.table.runtime.aggregate.AggregateUtil._ +import org.apache.flink.table.runtime.aggregate._ +import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval +import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} +import org.apache.flink.types.Row + +class DataStreamGroupWindowAggregate( + window: LogicalWindow, + namedProperties: Seq[NamedWindowProperty], + cluster: RelOptCluster, + 
traitSet: RelTraitSet, + inputNode: RelNode, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + rowRelDataType: RelDataType, + inputType: RelDataType, + grouping: Array[Int]) + extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { + + override def deriveRowType(): RelDataType = rowRelDataType + + override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { + new DataStreamGroupWindowAggregate( + window, + namedProperties, + cluster, + traitSet, + inputs.get(0), + namedAggregates, + getRowType, + inputType, + grouping) + } + + override def toString: String = { + s"Aggregate(${ + if (!grouping.isEmpty) { + s"groupBy: (${groupingToString(inputType, grouping)}), " + } else { + "" + } + }window: ($window), " + + s"select: (${ + aggregationToString( + inputType, + grouping, + getRowType, + namedAggregates, + namedProperties) + }))" + } + + override def explainTerms(pw: RelWriter): RelWriter = { + super.explainTerms(pw) + .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty) + .item("window", window) + .item( + "select", aggregationToString( + inputType, + grouping, + getRowType, + namedAggregates, + namedProperties)) + } + + override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { + + val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) + + val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) + + val aggString = aggregationToString( + inputType, + grouping, + getRowType, + namedAggregates, + namedProperties) + + val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + + s"window: ($window), " + + s"select: ($aggString)" + val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" + + val generator = new CodeGenerator( + tableEnv.getConfig, + false, + inputDS.getType) + + // grouped / keyed aggregation + if (grouping.length > 0) { + val windowFunction = 
AggregateUtil.createAggregationGroupWindowFunction( + window, + grouping.length, + namedAggregates.size, + rowRelDataType.getFieldCount, + namedProperties) + + val keyedStream = inputDS.keyBy(grouping: _*) + val windowedStream = + createKeyedWindowedStream(window, keyedStream) + .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] + + val (aggFunction, accumulatorRowType, aggResultRowType) = + AggregateUtil.createDataStreamAggregateFunction( + generator, + namedAggregates, + inputType, + rowRelDataType) + + windowedStream + .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) + .name(keyedAggOpName) + } + // global / non-keyed aggregation + else { + val windowFunction = AggregateUtil.createAggregationAllWindowFunction( + window, + rowRelDataType.getFieldCount, + namedProperties) + + val windowedStream = + createNonKeyedWindowedStream(window, inputDS) + .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] + + val (aggFunction, accumulatorRowType, aggResultRowType) = + AggregateUtil.createDataStreamAggregateFunction( + generator, + namedAggregates, + inputType, + rowRelDataType) + + windowedStream + .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) + .name(nonKeyedAggOpName) + } + } +} + +object DataStreamGroupWindowAggregate { + + + private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) + : WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { + + case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => + stream.window(TumblingProcessingTimeWindows.of(asTime(size))) + + case ProcessingTimeTumblingGroupWindow(_, size) => + stream.countWindow(asCount(size)) + + case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => + stream.window(TumblingEventTimeWindows.of(asTime(size))) + + case EventTimeTumblingGroupWindow(_, _, size) => + // TODO: 
EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => + stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) + + case ProcessingTimeSlidingGroupWindow(_, size, slide) => + stream.countWindow(asCount(size), asCount(slide)) + + case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => + stream.window(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) + + case EventTimeSlidingGroupWindow(_, _, size, slide) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case ProcessingTimeSessionGroupWindow(_, gap: Expression) => + stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap))) + + case EventTimeSessionGroupWindow(_, _, gap) => + stream.window(EventTimeSessionWindows.withGap(asTime(gap))) + } + + private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row]) + : AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match { + + case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => + stream.windowAll(TumblingProcessingTimeWindows.of(asTime(size))) + + case ProcessingTimeTumblingGroupWindow(_, size) => + stream.countWindowAll(asCount(size)) + + case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => + stream.windowAll(TumblingEventTimeWindows.of(asTime(size))) + + case 
EventTimeTumblingGroupWindow(_, _, size) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => + stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) + + case ProcessingTimeSlidingGroupWindow(_, size, slide) => + stream.countWindowAll(asCount(size), asCount(slide)) + + case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => + stream.windowAll(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) + + case EventTimeSlidingGroupWindow(_, _, size, slide) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. 
Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case ProcessingTimeSessionGroupWindow(_, gap) => + stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap))) + + case EventTimeSessionGroupWindow(_, _, gap) => + stream.windowAll(EventTimeSessionWindows.withGap(asTime(gap))) + } + + def asTime(expr: Expression): Time = expr match { + case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => Time.milliseconds(value) + case _ => throw new IllegalArgumentException() + } + + def asCount(expr: Expression): Long = expr match { + case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value + case _ => throw new IllegalArgumentException() + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index 0bee4e5..c16b469 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -173,8 +173,9 @@ object FlinkRuleSets { */ val DATASTREAM_OPT_RULES: RuleSet = RuleSets.ofList( // translate to DataStream nodes + DataStreamGroupAggregateRule.INSTANCE, DataStreamOverAggregateRule.INSTANCE, - DataStreamAggregateRule.INSTANCE, + DataStreamGroupWindowAggregateRule.INSTANCE, DataStreamCalcRule.INSTANCE, DataStreamScanRule.INSTANCE, DataStreamUnionRule.INSTANCE, 
http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala deleted file mode 100644 index f011b66..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.plan.rules.datastream - -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.convert.ConverterRule -import org.apache.flink.table.api.TableException -import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.datastream.DataStreamAggregate -import org.apache.flink.table.plan.nodes.logical.FlinkLogicalWindowAggregate - -import scala.collection.JavaConversions._ - -class DataStreamAggregateRule - extends ConverterRule( - classOf[FlinkLogicalWindowAggregate], - FlinkConventions.LOGICAL, - FlinkConventions.DATASTREAM, - "DataStreamAggregateRule") { - - override def matches(call: RelOptRuleCall): Boolean = { - val agg: FlinkLogicalWindowAggregate = call.rel(0).asInstanceOf[FlinkLogicalWindowAggregate] - - // check if we have distinct aggregates - val distinctAggs = agg.getAggCallList.exists(_.isDistinct) - if (distinctAggs) { - throw TableException("DISTINCT aggregates are currently not supported.") - } - - // check if we have grouping sets - val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet - if (groupSets || agg.indicator) { - throw TableException("GROUPING SETS are currently not supported.") - } - - !distinctAggs && !groupSets && !agg.indicator - } - - override def convert(rel: RelNode): RelNode = { - val agg: FlinkLogicalWindowAggregate = rel.asInstanceOf[FlinkLogicalWindowAggregate] - val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) - val convInput: RelNode = RelOptRule.convert(agg.getInput, FlinkConventions.DATASTREAM) - - new DataStreamAggregate( - agg.getWindow, - agg.getNamedProperties, - rel.getCluster, - traitSet, - convInput, - agg.getNamedAggCalls, - rel.getRowType, - agg.getInput.getRowType, - agg.getGroupSet.toArray) - } - } - -object DataStreamAggregateRule { - val INSTANCE: RelOptRule = new DataStreamAggregateRule 
-} http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala new file mode 100644 index 0000000..a65c378 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.rules.datastream + +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.rel.logical.LogicalAggregate +import org.apache.flink.table.api.TableException +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupAggregate +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalAggregate + +import scala.collection.JavaConversions._ + +/** + * Rule to convert a [[LogicalAggregate]] into a [[DataStreamGroupAggregate]]. + */ +class DataStreamGroupAggregateRule + extends ConverterRule( + classOf[FlinkLogicalAggregate], + FlinkConventions.LOGICAL, + FlinkConventions.DATASTREAM, + "DataStreamGroupAggregateRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: FlinkLogicalAggregate = call.rel(0).asInstanceOf[FlinkLogicalAggregate] + + // check if we have distinct aggregates + val distinctAggs = agg.getAggCallList.exists(_.isDistinct) + if (distinctAggs) { + throw TableException("DISTINCT aggregates are currently not supported.") + } + + // check if we have grouping sets + val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet + if (groupSets || agg.indicator) { + throw TableException("GROUPING SETS are currently not supported.") + } + + !distinctAggs && !groupSets && !agg.indicator + } + + override def convert(rel: RelNode): RelNode = { + val agg: FlinkLogicalAggregate = rel.asInstanceOf[FlinkLogicalAggregate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) + val convInput: RelNode = RelOptRule.convert(agg.getInput, FlinkConventions.DATASTREAM) + + new DataStreamGroupAggregate( + rel.getCluster, + traitSet, + convInput, + agg.getNamedAggCalls, + rel.getRowType, + agg.getInput.getRowType, + agg.getGroupSet.toArray) + } +} + 
+object DataStreamGroupAggregateRule { + val INSTANCE: RelOptRule = new DataStreamGroupAggregateRule +} + http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala new file mode 100644 index 0000000..fdf44a6 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.rules.datastream + +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.flink.table.api.TableException +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupWindowAggregate +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalWindowAggregate + +import scala.collection.JavaConversions._ + +class DataStreamGroupWindowAggregateRule + extends ConverterRule( + classOf[FlinkLogicalWindowAggregate], + FlinkConventions.LOGICAL, + FlinkConventions.DATASTREAM, + "DataStreamGroupWindowAggregateRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: FlinkLogicalWindowAggregate = call.rel(0).asInstanceOf[FlinkLogicalWindowAggregate] + + // check if we have distinct aggregates + val distinctAggs = agg.getAggCallList.exists(_.isDistinct) + if (distinctAggs) { + throw TableException("DISTINCT aggregates are currently not supported.") + } + + // check if we have grouping sets + val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet + if (groupSets || agg.indicator) { + throw TableException("GROUPING SETS are currently not supported.") + } + + !distinctAggs && !groupSets && !agg.indicator + } + + override def convert(rel: RelNode): RelNode = { + val agg: FlinkLogicalWindowAggregate = rel.asInstanceOf[FlinkLogicalWindowAggregate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) + val convInput: RelNode = RelOptRule.convert(agg.getInput, FlinkConventions.DATASTREAM) + + new DataStreamGroupWindowAggregate( + agg.getWindow, + agg.getNamedProperties, + rel.getCluster, + traitSet, + convInput, + agg.getNamedAggCalls, + rel.getRowType, + agg.getInput.getRowType, + agg.getGroupSet.toArray) + } + } + +object DataStreamGroupWindowAggregateRule { + 
val INSTANCE: RelOptRule = new DataStreamGroupWindowAggregateRule +} http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index a82f383..f93d870 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -81,8 +81,7 @@ object AggregateUtil { inputType, needRetract) - val aggregationStateType: RowTypeInfo = - createDataSetAggregateBufferDataType(Array(), aggregates, inputType) + val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) val forwardMapping = (0 until inputType.getFieldCount).toArray val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray @@ -132,7 +131,55 @@ object AggregateUtil { } /** - * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for group (without + * window) aggregate to evaluate final aggregate value. 
+ * + * @param generator code generator instance + * @param namedAggregates List of calls to aggregate functions and their output field names + * @param inputType Input row type + * @param groupings the position (in the input Row) of the grouping keys + * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] + */ + private[flink] def createGroupAggregateFunction( + generator: CodeGenerator, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + groupings: Array[Int]): ProcessFunction[Row, Row] = { + + val (aggFields, aggregates) = + transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputType, + needRetraction = false) + val aggMapping = aggregates.indices.map(_ + groupings.length).toArray + + val outputArity = groupings.length + aggregates.length + + val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) + + val genFunction = generator.generateAggregations( + "NonWindowedAggregationHelper", + generator, + inputType, + aggregates, + aggFields, + aggMapping, + partialResults = false, + groupings, + None, + None, + outputArity, + needRetract = false, + needMerge = false + ) + + new GroupAggProcessFunction( + genFunction, + aggregationStateType) + } + + /** + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for ROWS clause * bounded OVER window to evaluate final aggregate value. 
* * @param generator code generator instance @@ -253,7 +300,7 @@ object AggregateUtil { needRetract) val mapReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -355,7 +402,7 @@ object AggregateUtil { inputType, needRetract) - val returnType: RowTypeInfo = createDataSetAggregateBufferDataType( + val returnType: RowTypeInfo = createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -617,7 +664,7 @@ object AggregateUtil { window match { case EventTimeSessionGroupWindow(_, _, gap) => val combineReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -689,7 +736,7 @@ object AggregateUtil { case EventTimeSessionGroupWindow(_, _, gap) => val combineReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -1297,7 +1344,7 @@ object AggregateUtil { aggTypes } - private def createDataSetAggregateBufferDataType( + private def createRowTypeForKeysAndAggregates( groupings: Array[Int], aggregates: Array[TableAggregateFunction[_, _]], inputType: RelDataType, http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala new file mode 100644 index 0000000..81c900c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.types.Row +import org.apache.flink.util.Collector +import org.apache.flink.api.common.state.ValueStateDescriptor +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.common.state.ValueState +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.slf4j.LoggerFactory + +/** + * ProcessFunction that emits the updated non-windowed GROUP BY result for every input row. + * + * @param genAggregations Generated aggregate helper function + * @param aggregationStateType Row type of the accumulator row kept in value state + */ +class GroupAggProcessFunction( + private val genAggregations: GeneratedAggregationsFunction, + private val aggregationStateType: RowTypeInfo) + extends ProcessFunction[Row, Row] + with Compiler[GeneratedAggregations] { + + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ + + private var output: Row = _ + private var state: ValueState[Row] = _ + + override def open(config: Configuration) { + LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
s"Code:\n${genAggregations.code}") + val clazz = compile( + getRuntimeContext.getUserCodeClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + output = function.createOutputRow() + + val stateDescriptor: ValueStateDescriptor[Row] = + new ValueStateDescriptor[Row]("GroupAggregateState", aggregationStateType) + state = getRuntimeContext.getState(stateDescriptor) + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + // get accumulators + var accumulators = state.value() + if (null == accumulators) { + accumulators = function.createAccumulators() + } + + // Set group keys value to the final output + function.setForwardedFields(input, output) + + // accumulate new input row + function.accumulate(accumulators, input) + + // set aggregation results to output + function.setAggregationResults(accumulators, output) + + // update accumulators + state.update(accumulators) + + out.collect(output) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala index 7f87e50..b63eb81 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala @@ -38,7 +38,7 @@ import org.slf4j.LoggerFactory * Process Function used for the aggregate in bounded proc-time OVER window *
[[org.apache.flink.streaming.api.datastream.DataStream]] * - * @param genAggregations Generated aggregate helper function + * @param genAggregations Generated aggregate helper function * @param precedingTimeBoundary Is used to indicate the processing time boundaries * @param aggregatesTypeInfo row type info of aggregation * @param inputType row type info of input row http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala index 6ebfec0..033019b 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala @@ -231,7 +231,7 @@ class FieldProjectionTest extends TableTestBase { val expected = unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -259,7 +259,7 @@ class FieldProjectionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala index 67d13b0..f7bdccf 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala @@ -47,6 +47,27 @@ class SqlITCase extends StreamingWithStateTestBase { (8L, 8, "Hello World"), (20L, 20, "Hello World")) + /** test unbounded groupby (without window) **/ + @Test + def testUnboundedGroupby(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val sqlQuery = "SELECT b, COUNT(a) FROM MyTable GROUP BY b" + + val t = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("MyTable", t) + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("1,1", "2,1", "2,2") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + /** test selection **/ @Test def testSelectExpressionFromTable(): Unit = { http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala index 578a6a8..01263a3 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala @@ -61,7 +61,7 @@ class WindowAggregateTest extends TableTestBase { val sqlQuery = "SELECT a, AVG(c) OVER (PARTITION BY a ORDER BY procTime()" + "RANGE BETWEEN INTERVAL '2' HOUR PRECEDING AND CURRENT ROW) AS avgA " + "FROM MyTable" - val expected = + val expected = unaryNode( "DataStreamCalc", unaryNode( @@ -71,7 +71,7 @@ class WindowAggregateTest extends TableTestBase { streamTableNode(0), term("select", "a", "c", "PROCTIME() AS $2") ), - term("partitionBy","a"), + term("partitionBy", "a"), term("orderBy", "PROCTIME"), term("range", "BETWEEN 7200000 PRECEDING AND CURRENT ROW"), term("select", "a", "c", "PROCTIME", "COUNT(c) AS w0$o0", "$SUM0(c) AS w0$o1") @@ -83,6 +83,27 @@ class WindowAggregateTest extends TableTestBase { } @Test + def testGroupbyWithoutWindow() = { + val sql = "SELECT COUNT(a) FROM MyTable GROUP BY b" + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a") + ), + term("groupBy", "b"), + term("select", "b", "COUNT(a) AS EXPR$0") + ), + term("select", "EXPR$0") + ) + streamUtil.verifySql(sql, expected) + } + + @Test def testTumbleFunction() = { val sql = @@ -96,7 +117,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -122,7 +143,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -150,7 +171,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -200,21 +221,6 
@@ class WindowAggregateTest extends TableTestBase { streamUtil.verifySql(sql, "n/a") } - @Test(expected = classOf[TableException]) - def testMultiWindow() = { - val sql = "SELECT COUNT(*) FROM MyTable GROUP BY " + - "FLOOR(rowtime() TO HOUR), FLOOR(rowtime() TO MINUTE)" - val expected = "" - streamUtil.verifySql(sql, expected) - } - - @Test(expected = classOf[TableException]) - def testInvalidWindowExpression() = { - val sql = "SELECT COUNT(*) FROM MyTable GROUP BY FLOOR(localTimestamp TO HOUR)" - val expected = "" - streamUtil.verifySql(sql, expected) - } - @Test def testUnboundPartitionedProcessingWindowWithRange() = { val sql = "SELECT " + http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala deleted file mode 100644 index 9f366a8..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.api.scala.stream.table - -import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api.TimeCharacteristic -import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks -import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment -import org.apache.flink.streaming.api.watermark.Watermark -import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase -import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.scala.stream.table.AggregationsITCase.TimestampAndWatermarkWithOffset -import org.apache.flink.table.api.scala.stream.utils.StreamITCase -import org.apache.flink.types.Row -import org.junit.Assert._ -import org.junit.Test - -import scala.collection.mutable - -/** - * We only test some aggregations until better testing of constructed DataStream - * programs is possible. 
- */ -class AggregationsITCase extends StreamingMultipleProgramsTestBase { - - val data = List( - (1L, 1, "Hi"), - (2L, 2, "Hello"), - (4L, 2, "Hello"), - (8L, 3, "Hello world"), - (16L, 3, "Hello world")) - - @Test - def testProcessingTimeSlidingGroupWindowOverCount(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env.fromCollection(data) - val table = stream.toTable(tEnv, 'long, 'int, 'string) - - val windowedTable = table - .window(Slide over 2.rows every 1.rows as 'w) - .groupBy('w, 'string) - .select('string, 'int.count, 'int.avg) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq("Hello world,1,3", "Hello world,2,3", "Hello,1,2", "Hello,2,2", "Hi,1,1") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testEventTimeSessionGroupWindowOverTime(): Unit = { - //To verify the "merge" functionality, we create this test with the following characteristics: - // 1. set the Parallelism to 1, and have the test data out of order - // 2. 
create a waterMark with 10ms offset to delay the window emission by 10ms - val sessionWindowTestdata = List( - (1L, 1, "Hello"), - (2L, 2, "Hello"), - (8L, 8, "Hello"), - (9L, 9, "Hello World"), - (4L, 4, "Hello"), - (16L, 16, "Hello")) - - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - env.setParallelism(1) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env - .fromCollection(sessionWindowTestdata) - .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(10L)) - val table = stream.toTable(tEnv, 'long, 'int, 'string) - - val windowedTable = table - .window(Session withGap 5.milli on 'rowtime as 'w) - .groupBy('w, 'string) - .select('string, 'int.count, 'int.sum) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq("Hello World,1,9", "Hello,1,16", "Hello,4,15") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testAllProcessingTimeTumblingGroupWindowOverCount(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env.fromCollection(data) - val table = stream.toTable(tEnv, 'long, 'int, 'string) - - val windowedTable = table - .window(Tumble over 2.rows as 'w) - .groupBy('w) - .select('int.count) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq("2", "2") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testEventTimeTumblingWindow(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - 
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env - .fromCollection(data) - .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(0L)) - val table = stream.toTable(tEnv, 'long, 'int, 'string) - - val windowedTable = table - .window(Tumble over 5.milli on 'rowtime as 'w) - .groupBy('w, 'string) - .select('string, 'int.count, 'int.avg, 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq( - "Hello world,1,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", - "Hello world,1,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02", - "Hello,2,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", - "Hi,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } -} - -object AggregationsITCase { - class TimestampAndWatermarkWithOffset( - offset: Long) extends AssignerWithPunctuatedWatermarks[(Long, Int, String)] { - - override def checkAndGetNextWatermark( - lastElement: (Long, Int, String), - extractedTimestamp: Long) - : Watermark = { - new Watermark(extractedTimestamp - offset) - } - - override def extractTimestamp( - element: (Long, Int, String), - previousElementTimestamp: Long): Long = { - element._1 - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala new file 
mode 100644 index 0000000..271e90b --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.api.scala.stream.table + +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase} +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.types.Row +import org.junit.Assert.assertEquals +import org.junit.Test + +import scala.collection.mutable + +/** + * Tests of groupby (without window) aggregations + */ +class GroupAggregationsITCase extends StreamingWithStateTestBase { + + @Test + def testNonKeyedGroupAggregate(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .select('a.sum, 'b.sum) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,1", "3,3", "6,5", "10,8", "15,11", "21,14", "28,18", "36,22", "45,26", "55,30", "66,35", + "78,40", "91,45", "105,50", "120,55", "136,61", "153,67", "171,73", "190,79", "210,85", + "231,91") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testGroupAggregate(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('b, 'a.sum) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,1", "2,2", "2,5", "3,4", "3,9", "3,15", "4,7", "4,15", + "4,24", "4,34", 
"5,11", "5,23", "5,36", "5,50", "5,65", "6,16", "6,33", "6,51", "6,70", + "6,90", "6,111") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testDoubleGroupAggregation(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('a.sum as 'd, 'b) + .groupBy('b, 'd) + .select('b) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1", + "2", "2", + "3", "3", "3", + "4", "4", "4", "4", + "5", "5", "5", "5", "5", + "6", "6", "6", "6", "6", "6") + + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testGroupAggregateWithExpression(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + .groupBy('e, 'b % 3) + .select('c.min, 'e, 'a.avg, 'd.count) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "0,1,1,1", "1,2,2,1", "2,1,2,1", "3,2,3,1", "1,2,2,2", + "5,3,3,1", "3,2,3,2", "7,1,4,1", "2,1,3,2", "3,2,3,3", "7,1,4,2", "5,3,4,2", "12,3,5,1", + "1,2,3,3", "14,2,5,1") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala new file mode 100644 index 0000000..520592c --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.scala.stream.table + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.utils.TableTestBase +import org.junit.Test +import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala._ +import org.apache.flink.table.utils.TableTestUtil._ + +class GroupAggregationsTest extends TableTestBase { + + @Test(expected = classOf[ValidationException]) + def testGroupingOnNonExistentField(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val ds = table + // must fail. 
'_foo is not a valid field + .groupBy('_foo) + .select('a.avg) + } + + @Test(expected = classOf[ValidationException]) + def testGroupingInvalidSelection(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val ds = table + .groupBy('a, 'b) + // must fail. 'c is not a grouping key or aggregation + .select('c) + } + + @Test + def testGroupbyWithoutWindow() = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('a.count) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "b") + ), + term("groupBy", "b"), + term("select", "b", "COUNT(a) AS TMP_0") + ), + term("select", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + + @Test + def testGroupAggregateWithConstant1(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('a, 4 as 'four, 'b) + .groupBy('four, 'a) + .select('four, 'b.sum) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "4 AS four", "b", "a") + ), + term("groupBy", "four", "a"), + term("select", "four", "a", "SUM(b) AS TMP_0") + ), + term("select", "4 AS four", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithConstant2(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('b, 4 as 'four, 'a) + .groupBy('b, 'four) + .select('four, 'a.sum) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "4 AS four", "a", "b") + ), + term("groupBy", 
"four", "b"), + term("select", "four", "b", "SUM(a) AS TMP_0") + ), + term("select", "4 AS four", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithExpressionInSelect(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('a as 'a, 'b % 3 as 'd, 'c as 'c) + .groupBy('d) + .select('c.min, 'a.avg) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "MOD(b, 3) AS d", "c") + ), + term("groupBy", "d"), + term("select", "d", "MIN(c) AS TMP_0", "AVG(a) AS TMP_1") + ), + term("select", "TMP_0", "TMP_1") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithFilter(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('b, 'a.sum) + .where('b === 2) + + val expected = + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a"), + term("where", "=(b, 2)") + ), + term("groupBy", "b"), + term("select", "b", "SUM(a) AS TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithAverage(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('b, 'a.cast(BasicTypeInfo.DOUBLE_TYPE_INFO).avg) + + val expected = + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a", "CAST(a) AS a0") + ), + term("groupBy", "b"), + term("select", "b", "AVG(a0) AS TMP_0") + ) + + util.verifyTable(resultTable, expected) + } + +} 
http://git-wip-us.apache.org/repos/asf/flink/blob/24fa1a1c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowAggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowAggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowAggregationsITCase.scala new file mode 100644 index 0000000..3e3c57d --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowAggregationsITCase.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.api.scala.stream.table + +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.table.GroupWindowAggregationsITCase.TimestampAndWatermarkWithOffset +import org.apache.flink.table.api.scala.stream.utils.StreamITCase +import org.apache.flink.types.Row +import org.junit.Assert._ +import org.junit.Test + +import scala.collection.mutable + +/** + * We only test some aggregations until better testing of constructed DataStream + * programs is possible. + */ +class GroupWindowAggregationsITCase extends StreamingMultipleProgramsTestBase { + + val data = List( + (1L, 1, "Hi"), + (2L, 2, "Hello"), + (4L, 2, "Hello"), + (8L, 3, "Hello world"), + (16L, 3, "Hello world")) + + @Test + def testProcessingTimeSlidingGroupWindowOverCount(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env.fromCollection(data) + val table = stream.toTable(tEnv, 'long, 'int, 'string) + + val windowedTable = table + .window(Slide over 2.rows every 1.rows as 'w) + .groupBy('w, 'string) + .select('string, 'int.count, 'int.avg) + + val results = windowedTable.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = Seq("Hello world,1,3", "Hello world,2,3", "Hello,1,2", "Hello,2,2", "Hi,1,1") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) 
+ } + + @Test + def testEventTimeSessionGroupWindowOverTime(): Unit = { + //To verify the "merge" functionality, we create this test with the following characteristics: + // 1. set the Parallelism to 1, and have the test data out of order + // 2. create a waterMark with 10ms offset to delay the window emission by 10ms + val sessionWindowTestdata = List( + (1L, 1, "Hello"), + (2L, 2, "Hello"), + (8L, 8, "Hello"), + (9L, 9, "Hello World"), + (4L, 4, "Hello"), + (16L, 16, "Hello")) + + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env + .fromCollection(sessionWindowTestdata) + .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(10L)) + val table = stream.toTable(tEnv, 'long, 'int, 'string) + + val windowedTable = table + .window(Session withGap 5.milli on 'rowtime as 'w) + .groupBy('w, 'string) + .select('string, 'int.count, 'int.sum) + + val results = windowedTable.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = Seq("Hello World,1,9", "Hello,1,16", "Hello,4,15") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testAllProcessingTimeTumblingGroupWindowOverCount(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env.fromCollection(data) + val table = stream.toTable(tEnv, 'long, 'int, 'string) + + val windowedTable = table + .window(Tumble over 2.rows as 'w) + .groupBy('w) + .select('int.count) + + val results = windowedTable.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = Seq("2", "2") + 
assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testEventTimeTumblingWindow(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env + .fromCollection(data) + .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(0L)) + val table = stream.toTable(tEnv, 'long, 'int, 'string) + + val windowedTable = table + .window(Tumble over 5.milli on 'rowtime as 'w) + .groupBy('w, 'string) + .select('string, 'int.count, 'int.avg, 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end) + + val results = windowedTable.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = Seq( + "Hello world,1,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", + "Hello world,1,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02", + "Hello,2,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", + "Hi,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } +} + +object GroupWindowAggregationsITCase { + class TimestampAndWatermarkWithOffset( + offset: Long) extends AssignerWithPunctuatedWatermarks[(Long, Int, String)] { + + override def checkAndGetNextWatermark( + lastElement: (Long, Int, String), + extractedTimestamp: Long) + : Watermark = { + new Watermark(extractedTimestamp - offset) + } + + override def extractTimestamp( + element: (Long, Int, String), + previousElementTimestamp: Long): Long = { + element._1 + } + } +}
