[FLINK-4997] Add ProcessWindowFunction support for .aggregate()

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/fe2a3016
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/fe2a3016
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/fe2a3016

Branch: refs/heads/master
Commit: fe2a3016f98e45d0c94a3fa1ed8c17b89a516859
Parents: 4f047e1
Author: Aljoscha Krettek <[email protected]>
Authored: Tue Feb 7 14:38:25 2017 +0100
Committer: Aljoscha Krettek <[email protected]>
Committed: Fri Feb 17 17:15:51 2017 +0100

----------------------------------------------------------------------
 .../api/datastream/WindowedStream.java          | 128 +++++++++++++++++++
 .../InternalAggregateProcessWindowFunction.java |  84 ++++++++++++
 .../functions/InternalWindowFunctionTest.java   | 104 +++++++++++++++
 .../streaming/api/scala/WindowedStream.scala    |  30 +++++
 4 files changed, 346 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/fe2a3016/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
index 45eaae5..6809df0 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java
@@ -64,6 +64,7 @@ import 
org.apache.flink.streaming.api.windowing.windows.Window;
 import 
org.apache.flink.streaming.runtime.operators.windowing.AccumulatingProcessingTimeWindowOperator;
 import 
org.apache.flink.streaming.runtime.operators.windowing.AggregatingProcessingTimeWindowOperator;
 import 
org.apache.flink.streaming.runtime.operators.windowing.EvictingWindowOperator;
+import 
org.apache.flink.streaming.runtime.operators.windowing.functions.InternalAggregateProcessWindowFunction;
 import 
org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction;
 import 
org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableWindowFunction;
 import 
org.apache.flink.streaming.runtime.operators.windowing.functions.InternalSingleValueProcessWindowFunction;
@@ -906,6 +907,133 @@ public class WindowedStream<T, K, W extends Window> {
                return input.transform(opName, resultType, operator);
        }
 
+       /**
+        * Applies the given window function to each window. The window 
function is called for each
+        * evaluation of the window for each key individually. The output of 
the window function is
+        * interpreted as a regular non-windowed stream.
+        *
+        * <p>Arriving data is incrementally aggregated using the given 
aggregate function. This means
+        * that the window function typically has only a single value to 
process when called.
+        *
+        * @param aggFunction The aggregate function that is used for 
incremental aggregation.
+        * @param windowFunction The window function.
+        *
+        * @return The data stream that is the result of applying the window 
function to the window.
+        *
+        * @param <ACC> The type of the AggregateFunction's accumulator
+        * @param <V> The type of AggregateFunction's result, and the 
WindowFunction's input
+        * @param <R> The type of the elements in the resulting stream, equal 
to the
+        *            WindowFunction's result type
+        */
+       @PublicEvolving
+       public <ACC, V, R> SingleOutputStreamOperator<R> aggregate(
+                       AggregateFunction<T, ACC, V> aggFunction,
+                       ProcessWindowFunction<V, R, K, W> windowFunction) {
+
+               checkNotNull(aggFunction, "aggFunction");
+               checkNotNull(windowFunction, "windowFunction");
+
+               TypeInformation<ACC> accumulatorType = 
TypeExtractor.getAggregateFunctionAccumulatorType(
+                               aggFunction, input.getType(), null, false);
+
+               TypeInformation<V> aggResultType = 
TypeExtractor.getAggregateFunctionReturnType(
+                               aggFunction, input.getType(), null, false);
+
+               TypeInformation<R> resultType = 
TypeExtractor.getUnaryOperatorReturnType(
+                               windowFunction, ProcessWindowFunction.class, 
true, true, aggResultType, null, false);
+
+               return aggregate(aggFunction, windowFunction, accumulatorType, 
aggResultType, resultType);
+       }
+
+       /**
+        * Applies the given window function to each window. The window 
function is called for each
+        * evaluation of the window for each key individually. The output of 
the window function is
+        * interpreted as a regular non-windowed stream.
+        *
+        * <p>Arriving data is incrementally aggregated using the given 
aggregate function. This means
+        * that the window function typically has only a single value to 
process when called.
+        *
+        * @param aggregateFunction The aggregation function that is used for 
incremental aggregation.
+        * @param windowFunction The window function.
+        * @param accumulatorType Type information for the internal accumulator 
type of the aggregation function
+        * @param aggregateResultType Type information for the result type of 
the aggregation function
+        * @param resultType Type information for the result type of the window 
function
+        *
+        * @return The data stream that is the result of applying the window 
function to the window.
+        *
+        * @param <ACC> The type of the AggregateFunction's accumulator
+        * @param <V> The type of AggregateFunction's result, and the 
WindowFunction's input
+        * @param <R> The type of the elements in the resulting stream, equal 
to the
+        *            WindowFunction's result type
+        */
+       @PublicEvolving
+       public <ACC, V, R> SingleOutputStreamOperator<R> aggregate(
+                       AggregateFunction<T, ACC, V> aggregateFunction,
+                       ProcessWindowFunction<V, R, K, W> windowFunction,
+                       TypeInformation<ACC> accumulatorType,
+                       TypeInformation<V> aggregateResultType,
+                       TypeInformation<R> resultType) {
+
+               checkNotNull(aggregateFunction, "aggregateFunction");
+               checkNotNull(windowFunction, "windowFunction");
+               checkNotNull(accumulatorType, "accumulatorType");
+               checkNotNull(aggregateResultType, "aggregateResultType");
+               checkNotNull(resultType, "resultType");
+
+               if (aggregateFunction instanceof RichFunction) {
+                       throw new UnsupportedOperationException("This aggregate 
function cannot be a RichFunction.");
+               }
+
+               //clean the closures
+               windowFunction = 
input.getExecutionEnvironment().clean(windowFunction);
+               aggregateFunction = 
input.getExecutionEnvironment().clean(aggregateFunction);
+
+               String callLocation = Utils.getCallLocationName();
+               String udfName = "WindowedStream." + callLocation;
+
+               String opName;
+               KeySelector<T, K> keySel = input.getKeySelector();
+
+               OneInputStreamOperator<T, R> operator;
+
+               if (evictor != null) {
+                       @SuppressWarnings({"unchecked", "rawtypes"})
+                       TypeSerializer<StreamRecord<T>> streamRecordSerializer =
+                                       (TypeSerializer<StreamRecord<T>>) new 
StreamElementSerializer(input.getType().createSerializer(getExecutionEnvironment().getConfig()));
+
+                       ListStateDescriptor<StreamRecord<T>> stateDesc =
+                                       new 
ListStateDescriptor<>("window-contents", streamRecordSerializer);
+
+                       opName = "TriggerWindow(" + windowAssigner + ", " + 
stateDesc + ", " + trigger + ", " + evictor + ", " + udfName + ")";
+
+                       operator = new EvictingWindowOperator<>(windowAssigner,
+                                       
windowAssigner.getWindowSerializer(getExecutionEnvironment().getConfig()),
+                                       keySel,
+                                       
input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()),
+                                       stateDesc,
+                                       new 
InternalAggregateProcessWindowFunction<>(aggregateFunction, windowFunction),
+                                       trigger,
+                                       evictor,
+                                       allowedLateness);
+
+               } else {
+                       AggregatingStateDescriptor<T, ACC, V> stateDesc = new 
AggregatingStateDescriptor<>("window-contents",
+                                       aggregateFunction, 
accumulatorType.createSerializer(getExecutionEnvironment().getConfig()));
+
+                       opName = "TriggerWindow(" + windowAssigner + ", " + 
stateDesc + ", " + trigger + ", " + udfName + ")";
+
+                       operator = new WindowOperator<>(windowAssigner,
+                                       
windowAssigner.getWindowSerializer(getExecutionEnvironment().getConfig()),
+                                       keySel,
+                                       
input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()),
+                                       stateDesc,
+                                       new 
InternalSingleValueProcessWindowFunction<>(windowFunction),
+                                       trigger,
+                                       allowedLateness);
+               }
+
+               return input.transform(opName, resultType, operator);
+       }
+
        // 
------------------------------------------------------------------------
        //  Window Function (apply)
        // 
------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/fe2a3016/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalAggregateProcessWindowFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalAggregateProcessWindowFunction.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalAggregateProcessWindowFunction.java
new file mode 100644
index 0000000..433da9b
--- /dev/null
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalAggregateProcessWindowFunction.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.streaming.runtime.operators.windowing.functions;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.IterationRuntimeContext;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.java.operators.translation.WrappingFunction;
+import 
org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.Window;
+import org.apache.flink.util.Collector;
+
+import java.util.Collections;
+
+/**
+ * Internal window function that folds an {@code Iterable} of window elements 
through an
+ * {@link AggregateFunction} and passes the single result to the wrapped {@link ProcessWindowFunction}.
+ *
+ * @param <K> The key type
+ * @param <W> The window type
+ * @param <T> The type of the input to the AggregateFunction
+ * @param <ACC> The type of the AggregateFunction's accumulator
+ * @param <V> The type of the AggregateFunction's result, and the input to the 
WindowFunction
+ * @param <R> The result type of the WindowFunction
+ */
+public final class InternalAggregateProcessWindowFunction<T, ACC, V, R, K, W 
extends Window>
+               extends WrappingFunction<ProcessWindowFunction<V, R, K, W>>
+               implements InternalWindowFunction<Iterable<T>, R, K, W> {
+
+       private static final long serialVersionUID = 1L;
+
+       private final AggregateFunction<T, ACC, V> aggFunction;
+
+       public InternalAggregateProcessWindowFunction(
+                       AggregateFunction<T, ACC, V> aggFunction,
+                       ProcessWindowFunction<V, R, K, W> windowFunction) {
+               super(windowFunction);
+               this.aggFunction = aggFunction;
+       }
+       
+       @Override
+       public void apply(K key, final W window, Iterable<T> input, 
Collector<R> out) throws Exception {
+               ProcessWindowFunction<V, R, K, W> wrappedFunction = 
this.wrappedFunction;
+               ProcessWindowFunction<V, R, K, W>.Context context = 
wrappedFunction.new Context() {
+                       @Override
+                       public W window() {
+                               return window;
+                       }
+               };
+
+               final ACC acc = aggFunction.createAccumulator();
+
+               for (T val : input) {
+                       aggFunction.add(val, acc);
+               }
+
+               wrappedFunction.process(key, context, 
Collections.singletonList(aggFunction.getResult(acc)), out);
+       }
+
+       @Override
+       public RuntimeContext getRuntimeContext() {
+               throw new RuntimeException("This should never be called.");
+       }
+
+       @Override
+       public IterationRuntimeContext getIterationRuntimeContext() {
+               throw new RuntimeException("This should never be called.");
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/fe2a3016/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
index 3c73035..e49a496 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/windowing/functions/InternalWindowFunctionTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.api.operators.windowing.functions;
 
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -29,6 +30,7 @@ import 
org.apache.flink.streaming.api.functions.windowing.RichProcessWindowFunct
 import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction;
 import org.apache.flink.streaming.api.operators.OutputTypeConfigurable;
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import 
org.apache.flink.streaming.runtime.operators.windowing.functions.InternalAggregateProcessWindowFunction;
 import 
org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableAllWindowFunction;
 import 
org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction;
 import 
org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableWindowFunction;
@@ -39,6 +41,17 @@ import org.apache.flink.util.Collector;
 import org.hamcrest.collection.IsIterableContainingInOrder;
 import org.junit.Test;
 
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.collection.IsMapContaining.hasEntry;
+import static org.hamcrest.core.AllOf.allOf;
 import static org.mockito.Mockito.*;
 
 public class InternalWindowFunctionTest {
@@ -288,6 +301,84 @@ public class InternalWindowFunctionTest {
                verify(mock).close();
        }
 
+       @SuppressWarnings("unchecked")
+       @Test
+       public void testInternalAggregateProcessWindowFunction() throws 
Exception {
+
+               AggregateProcessWindowFunctionMock mock = 
mock(AggregateProcessWindowFunctionMock.class);
+
+               InternalAggregateProcessWindowFunction<Long, Set<Long>, 
Map<Long, Long>, String, Long, TimeWindow> windowFunction =
+                               new 
InternalAggregateProcessWindowFunction<>(new AggregateFunction<Long, Set<Long>, 
Map<Long, Long>>() {
+                                       private static final long 
serialVersionUID = 1L;
+                                       
+                                       @Override
+                                       public Set<Long> createAccumulator() {
+                                               return new HashSet<>();
+                                       }
+
+                                       @Override
+                                       public void add(Long value, Set<Long> 
accumulator) {
+                                               accumulator.add(value);
+                                       }
+
+                                       @Override
+                                       public Map<Long, Long> 
getResult(Set<Long> accumulator) {
+                                               Map<Long, Long> result = new 
HashMap<>();
+                                               for (Long in : accumulator) {
+                                                       result.put(in, in);
+                                               }
+                                               return result;
+                                       }
+
+                                       @Override
+                                       public Set<Long> merge(Set<Long> a, 
Set<Long> b) {
+                                               a.addAll(b);
+                                               return a;
+                                       }
+                               }, mock);
+
+               // check setOutputType
+               TypeInformation<String> stringType = 
BasicTypeInfo.STRING_TYPE_INFO;
+               ExecutionConfig execConf = new ExecutionConfig();
+               execConf.setParallelism(42);
+
+               StreamingFunctionUtils.setOutputType(windowFunction, 
stringType, execConf);
+               verify(mock).setOutputType(stringType, execConf);
+
+               // check open
+               Configuration config = new Configuration();
+
+               windowFunction.open(config);
+               verify(mock).open(config);
+
+               // check setRuntimeContext
+               RuntimeContext rCtx = mock(RuntimeContext.class);
+
+               windowFunction.setRuntimeContext(rCtx);
+               verify(mock).setRuntimeContext(rCtx);
+
+               // check apply
+               TimeWindow w = mock(TimeWindow.class);
+               Collector<String> c = (Collector<String>) mock(Collector.class);
+
+               List<Long> args = new LinkedList<>();
+               args.add(23L);
+               args.add(24L);
+               
+               windowFunction.apply(42L, w, args, c);
+               verify(mock).process(
+                               eq(42L),
+                               (AggregateProcessWindowFunctionMock.Context) 
anyObject(),
+                               (Iterable) argThat(containsInAnyOrder(allOf(
+                                               hasEntry(is(23L), is(23L)),
+                                               hasEntry(is(24L), is(24L))))),
+                               eq(c));
+
+               // check close
+               windowFunction.close();
+               verify(mock).close();
+       }
+
        public static class ProcessWindowFunctionMock
                extends RichProcessWindowFunction<Long, String, Long, 
TimeWindow>
                implements OutputTypeConfigurable<String> {
@@ -301,6 +392,19 @@ public class InternalWindowFunctionTest {
                public void process(Long aLong, Context context, Iterable<Long> 
input, Collector<String> out) throws Exception { }
        }
 
+       public static class AggregateProcessWindowFunctionMock
+                       extends RichProcessWindowFunction<Map<Long, Long>, 
String, Long, TimeWindow>
+                       implements OutputTypeConfigurable<String> {
+
+               private static final long serialVersionUID = 1L;
+
+               @Override
+               public void setOutputType(TypeInformation<String> outTypeInfo, 
ExecutionConfig executionConfig) { }
+
+               @Override
+               public void process(Long aLong, Context context, 
Iterable<Map<Long, Long>> input, Collector<String> out) throws Exception { }
+       }
+
        public static class WindowFunctionMock
                extends RichWindowFunction<Long, String, Long, TimeWindow>
                implements OutputTypeConfigurable<String> {

http://git-wip-us.apache.org/repos/asf/flink/blob/fe2a3016/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
----------------------------------------------------------------------
diff --git 
a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
index 96ff334..a5fbeb9 100644
--- 
a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
+++ 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala
@@ -326,6 +326,36 @@ class WindowedStream[T, K, W <: Window](javaStream: 
JavaWStream[T, K, W]) {
       accumulatorType, aggregationResultType, resultType))
   }
 
+  /**
+    * Applies the given window function to each window. The window function is 
called for each
+    * evaluation of the window for each key individually. The output of the 
window function is
+    * interpreted as a regular non-windowed stream.
+    *
+    * Arriving data is pre-aggregated using the given aggregation function.
+    *
+    * @param preAggregator The aggregation function that is used for 
pre-aggregation
+    * @param windowFunction The window function.
+    * @return The data stream that is the result of applying the window 
function to the window.
+    */
+  def aggregate[ACC: TypeInformation, V: TypeInformation, R: TypeInformation]
+  (preAggregator: AggregateFunction[T, ACC, V],
+   windowFunction: ProcessWindowFunction[V, R, K, W]): DataStream[R] = {
+
+    val cleanedPreAggregator = clean(preAggregator)
+    val cleanedWindowFunction = clean(windowFunction)
+
+    val applyFunction = new ScalaProcessWindowFunctionWrapper[V, R, K, 
W](cleanedWindowFunction)
+
+    val accumulatorType: TypeInformation[ACC] = 
implicitly[TypeInformation[ACC]]
+    val aggregationResultType: TypeInformation[V] = 
implicitly[TypeInformation[V]]
+    val resultType: TypeInformation[R] = implicitly[TypeInformation[R]]
+
+    asScalaStream(javaStream.aggregate(
+      cleanedPreAggregator, applyFunction,
+      accumulatorType, aggregationResultType, resultType))
+  }
+
+
   // ---------------------------- fold() ------------------------------------
 
   /**

Reply via email to