[FLINK-4997] Add ProcessWindowFunction support for .aggregate()
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/fe2a3016
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/fe2a3016
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/fe2a3016

Branch: refs/heads/master
Commit: fe2a3016f98e45d0c94a3fa1ed8c17b89a516859
Parents: 4f047e1
Author: Aljoscha Krettek <[email protected]>
Authored: Tue Feb 7 14:38:25 2017 +0100
Committer: Aljoscha Krettek <[email protected]>
Committed: Fri Feb 17 17:15:51 2017 +0100

----------------------------------------------------------------------
 .../api/datastream/WindowedStream.java          | 128 +++++++++++++++++++
 .../InternalAggregateProcessWindowFunction.java |  84 ++++++++++++
 .../functions/InternalWindowFunctionTest.java   | 104 +++++++++++++++
 .../streaming/api/scala/WindowedStream.scala    |  30 +++++
 4 files changed, 346 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/fe2a3016/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
index 45eaae5..6809df0 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
@@ -64,6 +64,7 @@ import org.apache.flink.streaming.api.windowing.windows.Window;
 import org.apache.flink.streaming.runtime.operators.windowing.AccumulatingProcessingTimeWindowOperator;
 import org.apache.flink.streaming.runtime.operators.windowing.AggregatingProcessingTimeWindowOperator;
 import org.apache.flink.streaming.runtime.operators.windowing.EvictingWindowOperator;
+import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalAggregateProcessWindowFunction;
 import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction;
 import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableWindowFunction;
 import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalSingleValueProcessWindowFunction;
@@ -906,6 +907,133 @@ public class WindowedStream<T, K, W extends Window> {
 		return input.transform(opName, resultType, operator);
 	}
 
+	/**
+	 * Applies the given window function to each window. The window function is called for each
+	 * evaluation of the window for each key individually. The output of the window function is
+	 * interpreted as a regular non-windowed stream.
+	 *
+	 * <p>Arriving data is incrementally aggregated using the given aggregate function. This means
+	 * that the window function typically has only a single value to process when called.
+	 *
+	 * @param aggFunction The aggregate function that is used for incremental aggregation.
+	 * @param windowFunction The window function.
+	 *
+	 * @return The data stream that is the result of applying the window function to the window.
+	 *
+	 * @param <ACC> The type of the AggregateFunction's accumulator
+	 * @param <V> The type of AggregateFunction's result, and the WindowFunction's input
+	 * @param <R> The type of the elements in the resulting stream, equal to the
+	 *            WindowFunction's result type
+	 */
+	@PublicEvolving
+	public <ACC, V, R> SingleOutputStreamOperator<R> aggregate(
+			AggregateFunction<T, ACC, V> aggFunction,
+			ProcessWindowFunction<V, R, K, W> windowFunction) {
+
+		checkNotNull(aggFunction, "aggFunction");
+		checkNotNull(windowFunction, "windowFunction");
+
+		TypeInformation<ACC> accumulatorType = TypeExtractor.getAggregateFunctionAccumulatorType(
+				aggFunction, input.getType(), null, false);
+
+		TypeInformation<V> aggResultType = TypeExtractor.getAggregateFunctionReturnType(
+				aggFunction, input.getType(), null, false);
+
+		TypeInformation<R> resultType = TypeExtractor.getUnaryOperatorReturnType(
+				windowFunction, ProcessWindowFunction.class, true, true, aggResultType, null, false);
+
+		return aggregate(aggFunction, windowFunction, accumulatorType, aggResultType, resultType);
+	}
+
+	/**
+	 * Applies the given window function to each window. The window function is called for each
+	 * evaluation of the window for each key individually. The output of the window function is
+	 * interpreted as a regular non-windowed stream.
+	 *
+	 * <p>Arriving data is incrementally aggregated using the given aggregate function. This means
+	 * that the window function typically has only a single value to process when called.
+	 *
+	 * @param aggregateFunction The aggregation function that is used for incremental aggregation.
+	 * @param windowFunction The window function.
+	 * @param accumulatorType Type information for the internal accumulator type of the aggregation function
+	 * @param resultType Type information for the result type of the window function
+	 *
+	 * @return The data stream that is the result of applying the window function to the window.
+	 *
+	 * @param <ACC> The type of the AggregateFunction's accumulator
+	 * @param <V> The type of AggregateFunction's result, and the WindowFunction's input
+	 * @param <R> The type of the elements in the resulting stream, equal to the
+	 *            WindowFunction's result type
+	 */
+	@PublicEvolving
+	public <ACC, V, R> SingleOutputStreamOperator<R> aggregate(
+			AggregateFunction<T, ACC, V> aggregateFunction,
+			ProcessWindowFunction<V, R, K, W> windowFunction,
+			TypeInformation<ACC> accumulatorType,
+			TypeInformation<V> aggregateResultType,
+			TypeInformation<R> resultType) {
+
+		checkNotNull(aggregateFunction, "aggregateFunction");
+		checkNotNull(windowFunction, "windowFunction");
+		checkNotNull(accumulatorType, "accumulatorType");
+		checkNotNull(aggregateResultType, "aggregateResultType");
+		checkNotNull(resultType, "resultType");
+
+		if (aggregateFunction instanceof RichFunction) {
+			throw new UnsupportedOperationException("This aggregate function cannot be a RichFunction.");
+		}
+
+		//clean the closures
+		windowFunction = input.getExecutionEnvironment().clean(windowFunction);
+		aggregateFunction = input.getExecutionEnvironment().clean(aggregateFunction);
+
+		String callLocation = Utils.getCallLocationName();
+		String udfName = "WindowedStream." + callLocation;
+
+		String opName;
+		KeySelector<T, K> keySel = input.getKeySelector();
+
+		OneInputStreamOperator<T, R> operator;
+
+		if (evictor != null) {
+			@SuppressWarnings({"unchecked", "rawtypes"})
+			TypeSerializer<StreamRecord<T>> streamRecordSerializer =
+					(TypeSerializer<StreamRecord<T>>) new StreamElementSerializer(input.getType().createSerializer(getExecutionEnvironment().getConfig()));
+
+			ListStateDescriptor<StreamRecord<T>> stateDesc =
+					new ListStateDescriptor<>("window-contents", streamRecordSerializer);
+
+			opName = "TriggerWindow(" + windowAssigner + ", " + stateDesc + ", " + trigger + ", " + evictor + ", " + udfName + ")";
+
+			operator = new EvictingWindowOperator<>(windowAssigner,
+					windowAssigner.getWindowSerializer(getExecutionEnvironment().getConfig()),
+					keySel,
+					input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()),
+					stateDesc,
+					new InternalAggregateProcessWindowFunction<>(aggregateFunction, windowFunction),
+					trigger,
+					evictor,
+					allowedLateness);
+
+		} else {
+			AggregatingStateDescriptor<T, ACC, V> stateDesc = new AggregatingStateDescriptor<>("window-contents",
+					aggregateFunction, accumulatorType.createSerializer(getExecutionEnvironment().getConfig()));
+
+			opName = "TriggerWindow(" + windowAssigner + ", " + stateDesc + ", " + trigger + ", " + udfName + ")";
+
+			operator = new WindowOperator<>(windowAssigner,
+					windowAssigner.getWindowSerializer(getExecutionEnvironment().getConfig()),
+					keySel,
+					input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()),
+					stateDesc,
+					new InternalSingleValueProcessWindowFunction<>(windowFunction),
+					trigger,
+					allowedLateness);
+		}
+
+		return input.transform(opName, resultType, operator);
+	}
+
 	// ------------------------------------------------------------------------
 	//  Window Function (apply)
 	// ------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/fe2a3016/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalAggregateProcessWindowFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalAggregateProcessWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalAggregateProcessWindowFunction.java
new file mode 100644
index 0000000..433da9b
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalAggregateProcessWindowFunction.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.streaming.runtime.operators.windowing.functions;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.IterationRuntimeContext;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.java.operators.translation.WrappingFunction;
+import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.Window;
+import org.apache.flink.util.Collector;
+
+import java.util.Collections;
+
+/**
+ * Internal window function for wrapping a {@link ProcessWindowFunction} that takes an
+ * {@code Iterable} and an {@link AggregateFunction}.
+ *
+ * @param <K> The key type
+ * @param <W> The window type
+ * @param <T> The type of the input to the AggregateFunction
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <V> The type of the AggregateFunction's result, and the input to the WindowFunction
+ * @param <R> The result type of the WindowFunction
+ */
+public final class InternalAggregateProcessWindowFunction<T, ACC, V, R, K, W extends Window>
+		extends WrappingFunction<ProcessWindowFunction<V, R, K, W>>
+		implements InternalWindowFunction<Iterable<T>, R, K, W> {
+
+	private static final long serialVersionUID = 1L;
+
+	private final AggregateFunction<T, ACC, V> aggFunction;
+
+	public InternalAggregateProcessWindowFunction(
+			AggregateFunction<T, ACC, V> aggFunction,
+			ProcessWindowFunction<V, R, K, W> windowFunction) {
+		super(windowFunction);
+		this.aggFunction = aggFunction;
+	}
+
+	@Override
+	public void apply(K key, final W window, Iterable<T> input, Collector<R> out) throws Exception {
+		ProcessWindowFunction<V, R, K, W> wrappedFunction = this.wrappedFunction;
+		ProcessWindowFunction<V, R, K, W>.Context context = wrappedFunction.new Context() {
+			@Override
+			public W window() {
+				return window;
+			}
+		};
+
+		final ACC acc = aggFunction.createAccumulator();
+
+		for (T val : input) {
+			aggFunction.add(val, acc);
+		}
+
+		wrappedFunction.process(key, context, Collections.singletonList(aggFunction.getResult(acc)), out);
+	}
+
+	@Override
+	public RuntimeContext getRuntimeContext() {
+		throw new RuntimeException("This should never be called.");
+	}
+
+	@Override
+	public IterationRuntimeContext getIterationRuntimeContext() {
+		throw new RuntimeException("This should never be called.");
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/fe2a3016/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
index 3c73035..e49a496 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
@@ -19,6 +19,7 @@ package org.apache.flink.streaming.api.operators.windowing.functions;
 
 
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -29,6 +30,7 @@ import org.apache.flink.streaming.api.functions.windowing.RichProcessWindowFunct
 import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction;
 import org.apache.flink.streaming.api.operators.OutputTypeConfigurable;
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalAggregateProcessWindowFunction;
 import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableAllWindowFunction;
 import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction;
 import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableWindowFunction;
@@ -39,6 +41,17 @@ import org.apache.flink.util.Collector;
 import org.hamcrest.collection.IsIterableContainingInOrder;
 import org.junit.Test;
 
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.collection.IsMapContaining.hasEntry;
+import static org.hamcrest.core.AllOf.allOf;
 import static org.mockito.Mockito.*;
 
 public class InternalWindowFunctionTest {
@@ -288,6 +301,84 @@ public class InternalWindowFunctionTest {
 		verify(mock).close();
 	}
 
+	@SuppressWarnings("unchecked")
+	@Test
+	public void testInternalAggregateProcessWindowFunction() throws Exception {
+
+		AggregateProcessWindowFunctionMock mock = mock(AggregateProcessWindowFunctionMock.class);
+
+		InternalAggregateProcessWindowFunction<Long, Set<Long>, Map<Long, Long>, String, Long, TimeWindow> windowFunction =
+				new InternalAggregateProcessWindowFunction<>(new AggregateFunction<Long, Set<Long>, Map<Long, Long>>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public Set<Long> createAccumulator() {
+						return new HashSet<>();
+					}
+
+					@Override
+					public void add(Long value, Set<Long> accumulator) {
+						accumulator.add(value);
+					}
+
+					@Override
+					public Map<Long, Long> getResult(Set<Long> accumulator) {
+						Map<Long, Long> result = new HashMap<>();
+						for (Long in : accumulator) {
+							result.put(in, in);
+						}
+						return result;
+					}
+
+					@Override
+					public Set<Long> merge(Set<Long> a, Set<Long> b) {
+						a.addAll(b);
+						return a;
+					}
+				}, mock);
+
+		// check setOutputType
+		TypeInformation<String> stringType = BasicTypeInfo.STRING_TYPE_INFO;
+		ExecutionConfig execConf = new ExecutionConfig();
+		execConf.setParallelism(42);
+
+		StreamingFunctionUtils.setOutputType(windowFunction, stringType, execConf);
+		verify(mock).setOutputType(stringType, execConf);
+
+		// check open
+		Configuration config = new Configuration();
+
+		windowFunction.open(config);
+		verify(mock).open(config);
+
+		// check setRuntimeContext
+		RuntimeContext rCtx = mock(RuntimeContext.class);
+
+		windowFunction.setRuntimeContext(rCtx);
+		verify(mock).setRuntimeContext(rCtx);
+
+		// check apply
+		TimeWindow w = mock(TimeWindow.class);
+		Collector<String> c = (Collector<String>) mock(Collector.class);
+
+		List<Long> args = new LinkedList<>();
+		args.add(23L);
+		args.add(24L);
+
+		windowFunction.apply(42L, w, args, c);
+		verify(mock).process(
+				eq(42L),
+				(AggregateProcessWindowFunctionMock.Context) anyObject(),
+				(Iterable) argThat(containsInAnyOrder(allOf(
+						hasEntry(is(23L), is(23L)),
+						hasEntry(is(24L), is(24L))))),
+				eq(c));
+
+		// check close
+		windowFunction.close();
+		verify(mock).close();
+	}
+
 	public static class ProcessWindowFunctionMock
 		extends RichProcessWindowFunction<Long, String, Long, TimeWindow>
 		implements OutputTypeConfigurable<String> {
@@ -301,6 +392,19 @@ public class InternalWindowFunctionTest {
 		public void process(Long aLong, Context context, Iterable<Long> input, Collector<String> out) throws Exception { }
 	}
 
+	public static class AggregateProcessWindowFunctionMock
+			extends RichProcessWindowFunction<Map<Long, Long>, String, Long, TimeWindow>
+			implements OutputTypeConfigurable<String> {
+
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public void setOutputType(TypeInformation<String> outTypeInfo, ExecutionConfig executionConfig) { }
+
+		@Override
+		public void process(Long aLong, Context context, Iterable<Map<Long, Long>> input, Collector<String> out) throws Exception { }
+	}
+
 	public static class WindowFunctionMock
 		extends RichWindowFunction<Long, String, Long, TimeWindow>
 		implements OutputTypeConfigurable<String> {
http://git-wip-us.apache.org/repos/asf/flink/blob/fe2a3016/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
index 96ff334..a5fbeb9 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
@@ -326,6 +326,36 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) {
       accumulatorType, aggregationResultType, resultType))
   }
 
+  /**
+   * Applies the given window function to each window. The window function is called for each
+   * evaluation of the window for each key individually. The output of the window function is
+   * interpreted as a regular non-windowed stream.
+   *
+   * Arriving data is pre-aggregated using the given aggregation function.
+   *
+   * @param preAggregator The aggregation function that is used for pre-aggregation
+   * @param windowFunction The window function.
+   * @return The data stream that is the result of applying the window function to the window.
+   */
+  def aggregate[ACC: TypeInformation, V: TypeInformation, R: TypeInformation]
+      (preAggregator: AggregateFunction[T, ACC, V],
+       windowFunction: ProcessWindowFunction[V, R, K, W]): DataStream[R] = {
+
+    val cleanedPreAggregator = clean(preAggregator)
+    val cleanedWindowFunction = clean(windowFunction)
+
+    val applyFunction = new ScalaProcessWindowFunctionWrapper[V, R, K, W](cleanedWindowFunction)
+
+    val accumulatorType: TypeInformation[ACC] = implicitly[TypeInformation[ACC]]
+    val aggregationResultType: TypeInformation[V] = implicitly[TypeInformation[V]]
+    val resultType: TypeInformation[R] = implicitly[TypeInformation[R]]
+
+    asScalaStream(javaStream.aggregate(
+      cleanedPreAggregator, applyFunction,
+      accumulatorType, aggregationResultType, resultType))
+  }
+
+
   // ---------------------------- fold() ------------------------------------
 
   /**
