[FLINK-5250] Make AbstractUdfStreamOperator aware of WrappingFunction
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8492d9b7 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8492d9b7 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8492d9b7 Branch: refs/heads/master Commit: 8492d9b7b92674db309f177e48782b11d9d1be5a Parents: ce02350 Author: Stefan Richter <[email protected]> Authored: Fri Jan 13 12:09:58 2017 +0100 Committer: Aljoscha Krettek <[email protected]> Committed: Fri Jan 13 22:47:25 2017 +0100 ---------------------------------------------------------------------- .../functions/util/StreamingFunctionUtils.java | 113 +++++++++++ .../windowing/ReduceApplyAllWindowFunction.java | 4 +- .../windowing/ReduceApplyWindowFunction.java | 4 +- .../operators/AbstractUdfStreamOperator.java | 50 +---- .../InternalIterableAllWindowFunction.java | 20 -- .../InternalIterableWindowFunction.java | 21 --- .../InternalSingleValueAllWindowFunction.java | 21 --- .../InternalSingleValueWindowFunction.java | 20 -- .../WrappingFunctionSnapshotRestoreTest.java | 187 +++++++++++++++++++ .../util/ScalaAllWindowFunctionWrapper.scala | 29 +-- .../util/ScalaWindowFunctionWrapper.scala | 36 ++-- 11 files changed, 322 insertions(+), 183 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/util/StreamingFunctionUtils.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/util/StreamingFunctionUtils.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/util/StreamingFunctionUtils.java index f167f7f..d1d264f 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/util/StreamingFunctionUtils.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/util/StreamingFunctionUtils.java @@ -86,6 +86,119 @@ public final class StreamingFunctionUtils { return false; } + public static void snapshotFunctionState( + StateSnapshotContext context, + OperatorStateBackend backend, + Function userFunction) throws Exception { + + Preconditions.checkNotNull(context); + Preconditions.checkNotNull(backend); + + while (true) { + + if (trySnapshotFunctionState(context, backend, userFunction)) { + break; + } + + // inspect if the user function is wrapped, then unwrap and try again if we can snapshot the inner function + if (userFunction instanceof WrappingFunction) { + userFunction = ((WrappingFunction<?>) userFunction).getWrappedFunction(); + } else { + break; + } + } + } + + private static boolean trySnapshotFunctionState( + StateSnapshotContext context, + OperatorStateBackend backend, + Function userFunction) throws Exception { + + if (userFunction instanceof CheckpointedFunction) { + ((CheckpointedFunction) userFunction).snapshotState(context); + + return true; + } + + if (userFunction instanceof ListCheckpointed) { + @SuppressWarnings("unchecked") + List<Serializable> partitionableState = ((ListCheckpointed<Serializable>) userFunction). + snapshotState(context.getCheckpointId(), context.getCheckpointTimestamp()); + + ListState<Serializable> listState = backend. + getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + + listState.clear(); + + if (null != partitionableState) { + for (Serializable statePartition : partitionableState) { + listState.add(statePartition); + } + } + + return true; + } + + return false; + } + + public static void restoreFunctionState( + StateInitializationContext context, + Function userFunction) throws Exception { + + Preconditions.checkNotNull(context); + + while (true) { + + if (tryRestoreFunction(context, userFunction)) { + break; + } + + // inspect if the user function is wrapped, then unwrap and try again if we can restore the inner function + if (userFunction instanceof WrappingFunction) { + userFunction = ((WrappingFunction<?>) userFunction).getWrappedFunction(); + } else { + break; + } + } + } + + private static boolean tryRestoreFunction( + StateInitializationContext context, + Function userFunction) throws Exception { + + if (userFunction instanceof CheckpointedFunction) { + ((CheckpointedFunction) userFunction).initializeState(context); + + return true; + } + + if (context.isRestored() && userFunction instanceof ListCheckpointed) { + @SuppressWarnings("unchecked") + ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction; + + ListState<Serializable> listState = context.getOperatorStateStore(). + getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + + List<Serializable> list = new ArrayList<>(); + + for (Serializable serializable : listState.get()) { + list.add(serializable); + } + + try { + listCheckpointedFun.restoreState(list); + } catch (Exception e) { + + throw new Exception("Failed to restore state to function: " + e.getMessage(), e); + } + + return true; + } + + return false; + } + /** * Private constructor to prevent instantiation. */ http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyAllWindowFunction.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyAllWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyAllWindowFunction.java index 5b8dd70..46a6456 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyAllWindowFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyAllWindowFunction.java @@ -33,13 +33,11 @@ public class ReduceApplyAllWindowFunction<W extends Window, T, R> private static final long serialVersionUID = 1L; private final ReduceFunction<T> reduceFunction; - private final AllWindowFunction<T, R, W> windowFunction; public ReduceApplyAllWindowFunction(ReduceFunction<T> reduceFunction, AllWindowFunction<T, R, W> windowFunction) { super(windowFunction); this.reduceFunction = reduceFunction; - this.windowFunction = windowFunction; } @Override @@ -53,6 +51,6 @@ public class ReduceApplyAllWindowFunction<W extends Window, T, R> curr = reduceFunction.reduce(curr, val); } } - windowFunction.apply(window, Collections.singletonList(curr), out); + wrappedFunction.apply(window, Collections.singletonList(curr), out); } } http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyWindowFunction.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyWindowFunction.java index f896282..6e1ba27 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyWindowFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/windowing/ReduceApplyWindowFunction.java @@ -33,13 +33,11 @@ public class ReduceApplyWindowFunction<K, W extends Window, T, R> private static final long serialVersionUID = 1L; private final ReduceFunction<T> reduceFunction; - private final WindowFunction<T, R, K, W> windowFunction; public ReduceApplyWindowFunction(ReduceFunction<T> reduceFunction, WindowFunction<T, R, K, W> windowFunction) { super(windowFunction); this.reduceFunction = reduceFunction; - this.windowFunction = windowFunction; } @Override @@ -53,6 +51,6 @@ public class ReduceApplyWindowFunction<K, W extends Window, T, R> curr = reduceFunction.reduce(curr, val); } } - windowFunction.apply(k, window, Collections.singletonList(curr), out); + wrappedFunction.apply(k, window, Collections.singletonList(curr), out); } } http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index 9f67156..166287b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -22,13 +22,11 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.functions.util.FunctionUtils; -import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.runtime.state.CheckpointListener; -import org.apache.flink.runtime.state.DefaultOperatorStateBackend; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.checkpoint.Checkpointed; @@ -43,8 +41,6 @@ import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Migration; import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; import static java.util.Objects.requireNonNull; @@ -94,7 +90,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> @Override public void setup(StreamTask<?, ?> containingTask, StreamConfig config, Output<StreamRecord<OUT>> output) { super.setup(containingTask, config, output); - FunctionUtils.setFunctionRuntimeContext(userFunction, getRuntimeContext()); } @@ -102,54 +97,13 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - - if (userFunction instanceof CheckpointedFunction) { - ((CheckpointedFunction) userFunction).snapshotState(context); - } else if (userFunction instanceof ListCheckpointed) { - @SuppressWarnings("unchecked") - List<Serializable> partitionableState = ((ListCheckpointed<Serializable>) userFunction). - snapshotState(context.getCheckpointId(), context.getCheckpointTimestamp()); - - ListState<Serializable> listState = getOperatorStateBackend(). - getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); - - listState.clear(); - - if (null != partitionableState) { - for (Serializable statePartition : partitionableState) { - listState.add(statePartition); - } - } - } - + StreamingFunctionUtils.snapshotFunctionState(context, getOperatorStateBackend(), userFunction); } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - - if (userFunction instanceof CheckpointedFunction) { - ((CheckpointedFunction) userFunction).initializeState(context); - } else if (context.isRestored() && userFunction instanceof ListCheckpointed) { - @SuppressWarnings("unchecked") - ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction; - - ListState<Serializable> listState = context.getOperatorStateStore(). - getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); - - List<Serializable> list = new ArrayList<>(); - - for (Serializable serializable : listState.get()) { - list.add(serializable); - } - - try { - listCheckpointedFun.restoreState(list); - } catch (Exception e) { - throw new Exception("Failed to restore state to function: " + e.getMessage(), e); - } - } - + StreamingFunctionUtils.restoreFunctionState(context, userFunction); } @Override http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableAllWindowFunction.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableAllWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableAllWindowFunction.java index b2adc94..672bdb6 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableAllWindowFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableAllWindowFunction.java @@ -19,9 +19,7 @@ package org.apache.flink.streaming.runtime.operators.windowing.functions; import org.apache.flink.api.common.functions.IterationRuntimeContext; import org.apache.flink.api.common.functions.RuntimeContext; -import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.java.operators.translation.WrappingFunction; -import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; import org.apache.flink.streaming.api.windowing.windows.Window; import org.apache.flink.util.Collector; @@ -36,11 +34,8 @@ public final class InternalIterableAllWindowFunction<IN, OUT, W extends Window> private static final long serialVersionUID = 1L; - protected final AllWindowFunction<IN, OUT, W> wrappedFunction; - public InternalIterableAllWindowFunction(AllWindowFunction<IN, OUT, W> wrappedFunction) { super(wrappedFunction); - this.wrappedFunction = wrappedFunction; } @Override @@ -49,21 +44,6 @@ public final class InternalIterableAllWindowFunction<IN, OUT, W extends Window> } @Override - public void open(Configuration parameters) throws Exception { - FunctionUtils.openFunction(this.wrappedFunction, parameters); - } - - @Override - public void close() throws Exception { - FunctionUtils.closeFunction(this.wrappedFunction); - } - - @Override - public void setRuntimeContext(RuntimeContext t) { - FunctionUtils.setFunctionRuntimeContext(this.wrappedFunction, t); - } - - @Override public RuntimeContext getRuntimeContext() { throw new RuntimeException("This should never be called."); } http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableWindowFunction.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableWindowFunction.java index 821d40a..895b31f 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableWindowFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalIterableWindowFunction.java @@ -19,9 +19,7 @@ package org.apache.flink.streaming.runtime.operators.windowing.functions; import org.apache.flink.api.common.functions.IterationRuntimeContext; import org.apache.flink.api.common.functions.RuntimeContext; -import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.java.operators.translation.WrappingFunction; -import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.functions.windowing.WindowFunction; import org.apache.flink.streaming.api.windowing.windows.Window; import org.apache.flink.util.Collector; @@ -36,11 +34,8 @@ public final class InternalIterableWindowFunction<IN, OUT, KEY, W extends Window private static final long serialVersionUID = 1L; - protected final WindowFunction<IN, OUT, KEY, W> wrappedFunction; - public InternalIterableWindowFunction(WindowFunction<IN, OUT, KEY, W> wrappedFunction) { super(wrappedFunction); - this.wrappedFunction = wrappedFunction; } @Override @@ -49,21 +44,6 @@ public final class InternalIterableWindowFunction<IN, OUT, KEY, W extends Window } @Override - public void open(Configuration parameters) throws Exception { - FunctionUtils.openFunction(this.wrappedFunction, parameters); - } - - @Override - public void close() throws Exception { - FunctionUtils.closeFunction(this.wrappedFunction); - } - - @Override - public void setRuntimeContext(RuntimeContext t) { - FunctionUtils.setFunctionRuntimeContext(this.wrappedFunction, t); - } - - @Override public RuntimeContext getRuntimeContext() { throw new RuntimeException("This should never be called."); } @@ -71,6 +51,5 @@ public final class InternalIterableWindowFunction<IN, OUT, KEY, W extends Window @Override public IterationRuntimeContext getIterationRuntimeContext() { throw new RuntimeException("This should never be called."); - } } http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueAllWindowFunction.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueAllWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueAllWindowFunction.java index 7cdc31c..a34d3ec 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueAllWindowFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueAllWindowFunction.java @@ -19,9 +19,7 @@ package org.apache.flink.streaming.runtime.operators.windowing.functions; import org.apache.flink.api.common.functions.IterationRuntimeContext; import org.apache.flink.api.common.functions.RuntimeContext; -import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.java.operators.translation.WrappingFunction; -import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; import org.apache.flink.streaming.api.windowing.windows.Window; import org.apache.flink.util.Collector; @@ -38,11 +36,8 @@ public final class InternalSingleValueAllWindowFunction<IN, OUT, W extends Windo private static final long serialVersionUID = 1L; - protected AllWindowFunction<IN, OUT, W> wrappedFunction; - public InternalSingleValueAllWindowFunction(AllWindowFunction<IN, OUT, W> wrappedFunction) { super(wrappedFunction); - this.wrappedFunction = wrappedFunction; } @Override @@ -51,21 +46,6 @@ public final class InternalSingleValueAllWindowFunction<IN, OUT, W extends Windo } @Override - public void open(Configuration parameters) throws Exception { - FunctionUtils.openFunction(this.wrappedFunction, parameters); - } - - @Override - public void close() throws Exception { - FunctionUtils.closeFunction(this.wrappedFunction); - } - - @Override - public void setRuntimeContext(RuntimeContext t) { - FunctionUtils.setFunctionRuntimeContext(this.wrappedFunction, t); - } - - @Override public RuntimeContext getRuntimeContext() { throw new RuntimeException("This should never be called."); } @@ -73,6 +53,5 @@ public final class InternalSingleValueAllWindowFunction<IN, OUT, W extends Windo @Override public IterationRuntimeContext getIterationRuntimeContext() { throw new RuntimeException("This should never be called."); - } } http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueWindowFunction.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueWindowFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueWindowFunction.java index e98872b..9a0a447 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueWindowFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/functions/InternalSingleValueWindowFunction.java @@ -19,9 +19,7 @@ package org.apache.flink.streaming.runtime.operators.windowing.functions; import org.apache.flink.api.common.functions.IterationRuntimeContext; import org.apache.flink.api.common.functions.RuntimeContext; -import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.java.operators.translation.WrappingFunction; -import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.functions.windowing.WindowFunction; import org.apache.flink.streaming.api.windowing.windows.Window; import org.apache.flink.util.Collector; @@ -38,11 +36,8 @@ public final class InternalSingleValueWindowFunction<IN, OUT, KEY, W extends Win private static final long serialVersionUID = 1L; - protected WindowFunction<IN, OUT, KEY, W> wrappedFunction; - public InternalSingleValueWindowFunction(WindowFunction<IN, OUT, KEY, W> wrappedFunction) { super(wrappedFunction); - this.wrappedFunction = wrappedFunction; } @Override @@ -51,21 +46,6 @@ public final class InternalSingleValueWindowFunction<IN, OUT, KEY, W extends Win } @Override - public void open(Configuration parameters) throws Exception { - FunctionUtils.openFunction(this.wrappedFunction, parameters); - } - - @Override - public void close() throws Exception { - FunctionUtils.closeFunction(this.wrappedFunction); - } - - @Override - public void setRuntimeContext(RuntimeContext t) { - FunctionUtils.setFunctionRuntimeContext(this.wrappedFunction, t); - } - - @Override public RuntimeContext getRuntimeContext() { throw new RuntimeException("This should never be called."); } http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/WrappingFunctionSnapshotRestoreTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/WrappingFunctionSnapshotRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/WrappingFunctionSnapshotRestoreTest.java new file mode 100644 index 0000000..b1689f9 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/WrappingFunctionSnapshotRestoreTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.java.operators.translation.WrappingFunction; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +public class WrappingFunctionSnapshotRestoreTest { + + + @Test + public void testSnapshotAndRestoreWrappedCheckpointedFunction() throws Exception { + + StreamMap<Integer, Integer> operator = new StreamMap<>( + new WrappingTestFun(new WrappingTestFun(new InnerTestFun()))); + + OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = + new OneInputStreamOperatorTestHarness<>(operator); + + testHarness.setup(); + testHarness.open(); + + testHarness.processElement(new StreamRecord<>(5, 12L)); + + // snapshot and restore from scratch + OperatorStateHandles snapshot = testHarness.snapshot(0, 0); + + testHarness.close(); + + InnerTestFun innerTestFun = new InnerTestFun(); + operator = new StreamMap<>(new WrappingTestFun(new WrappingTestFun(innerTestFun))); + + testHarness = new OneInputStreamOperatorTestHarness<>(operator); + + testHarness.setup(); + testHarness.initializeState(snapshot); + testHarness.open(); + + Assert.assertTrue(innerTestFun.wasRestored); + testHarness.close(); + } + + @Test + public void testSnapshotAndRestoreWrappedListCheckpointed() throws Exception { + + StreamMap<Integer, Integer> operator = new StreamMap<>( + new WrappingTestFun(new WrappingTestFun(new InnerTestFunList()))); + + OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = + new OneInputStreamOperatorTestHarness<>(operator); + + testHarness.setup(); + testHarness.open(); + + testHarness.processElement(new StreamRecord<>(5, 12L)); + + // snapshot and restore from scratch + OperatorStateHandles snapshot = testHarness.snapshot(0, 0); + + testHarness.close(); + + InnerTestFunList innerTestFun = new InnerTestFunList(); + operator = new StreamMap<>(new WrappingTestFun(new WrappingTestFun(innerTestFun))); + + testHarness = new OneInputStreamOperatorTestHarness<>(operator); + + testHarness.setup(); + testHarness.initializeState(snapshot); + testHarness.open(); + + Assert.assertTrue(innerTestFun.wasRestored); + testHarness.close(); + } + + static class WrappingTestFun + extends WrappingFunction<MapFunction<Integer, Integer>> implements MapFunction<Integer, Integer> { + + private static final long serialVersionUID = 1L; + + public WrappingTestFun(MapFunction<Integer, Integer> wrappedFunction) { + super(wrappedFunction); + } + + @Override + public Integer map(Integer value) throws Exception { + return value; + } + } + + static class InnerTestFun + extends AbstractRichFunction implements MapFunction<Integer, Integer>, CheckpointedFunction { + + private static final long serialVersionUID = 1L; + + private ListState<Integer> serializableListState; + private boolean wasRestored; + + public InnerTestFun() { + wasRestored = false; + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + if (!wasRestored) { + serializableListState.add(42); + } + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + serializableListState = context.getOperatorStateStore().getSerializableListState("test-state"); + if (context.isRestored()) { + Iterator<Integer> integers = serializableListState.get().iterator(); + int act = integers.next(); + Assert.assertEquals(42, act); + Assert.assertFalse(integers.hasNext()); + wasRestored = true; + } + } + + @Override + public Integer map(Integer value) throws Exception { + return value; + } + } + + static class InnerTestFunList + extends AbstractRichFunction implements MapFunction<Integer, Integer>, ListCheckpointed<Integer> { + + private static final long serialVersionUID = 1L; + + private boolean wasRestored; + + public InnerTestFunList() { + wasRestored = false; + } + + @Override + public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception { + return Collections.singletonList(42); + } + + @Override + public void restoreState(List<Integer> state) throws Exception { + Assert.assertEquals(1, state.size()); + int val = state.get(0); + Assert.assertEquals(42, val); + wasRestored = true; + } + + @Override + public Integer map(Integer value) throws Exception { + return value; + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaAllWindowFunctionWrapper.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaAllWindowFunctionWrapper.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaAllWindowFunctionWrapper.scala index 39142c2..6db7236 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaAllWindowFunctionWrapper.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaAllWindowFunctionWrapper.scala @@ -18,9 +18,8 @@ package org.apache.flink.streaming.api.scala.function.util -import org.apache.flink.api.common.functions.util.FunctionUtils import org.apache.flink.api.common.functions.{IterationRuntimeContext, RichFunction, RuntimeContext} -import org.apache.flink.configuration.Configuration +import org.apache.flink.api.java.operators.translation.WrappingFunction import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction => JAllWindowFunction} import org.apache.flink.streaming.api.scala.function.AllWindowFunction import org.apache.flink.streaming.api.windowing.windows.Window @@ -35,34 +34,20 @@ import scala.collection.JavaConverters._ * - Scala WindowFunction: scala.Iterable * - Java WindowFunction: java.lang.Iterable */ -final class ScalaAllWindowFunctionWrapper[IN, OUT, W <: Window]( - private[this] val func: AllWindowFunction[IN, OUT, W]) - extends JAllWindowFunction[IN, OUT, W] with RichFunction { +final class ScalaAllWindowFunctionWrapper[IN, OUT, W <: Window](func: AllWindowFunction[IN, OUT, W]) + extends WrappingFunction[AllWindowFunction[IN, OUT, W]](func) + with JAllWindowFunction[IN, OUT, W] with RichFunction { @throws(classOf[Exception]) override def apply(window: W, input: java.lang.Iterable[IN], out: Collector[OUT]) { - func.apply(window, input.asScala, out) + wrappedFunction.apply(window, input.asScala, out) } - @throws(classOf[Exception]) - override def open(parameters: Configuration) { - FunctionUtils.openFunction(func, parameters) - } - - @throws(classOf[Exception]) - override def close() { - FunctionUtils.closeFunction(func) - } - - override def setRuntimeContext(t: RuntimeContext) { - FunctionUtils.setFunctionRuntimeContext(func, t) - } - - override def getRuntimeContext(): RuntimeContext = { + override def getRuntimeContext: RuntimeContext = { throw new RuntimeException("This should never be called") } - override def getIterationRuntimeContext(): IterationRuntimeContext = { + override def getIterationRuntimeContext: IterationRuntimeContext = { throw new RuntimeException("This should never be called") } } http://git-wip-us.apache.org/repos/asf/flink/blob/8492d9b7/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaWindowFunctionWrapper.scala ---------------------------------------------------------------------- diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaWindowFunctionWrapper.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaWindowFunctionWrapper.scala index 1d74b6c..a074cd9 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaWindowFunctionWrapper.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/function/util/ScalaWindowFunctionWrapper.scala @@ -18,10 +18,9 @@ package org.apache.flink.streaming.api.scala.function.util -import org.apache.flink.api.common.functions.{IterationRuntimeContext, RuntimeContext, RichFunction} -import org.apache.flink.api.common.functions.util.FunctionUtils -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.windowing.{ WindowFunction => JWindowFunction } +import org.apache.flink.api.common.functions.{IterationRuntimeContext, RichFunction, RuntimeContext} +import org.apache.flink.api.java.operators.translation.WrappingFunction +import org.apache.flink.streaming.api.functions.windowing.{WindowFunction => JWindowFunction} import org.apache.flink.streaming.api.scala.function.WindowFunction import org.apache.flink.streaming.api.windowing.windows.Window import org.apache.flink.util.Collector @@ -35,34 +34,21 @@ import scala.collection.JavaConverters._ * - Scala WindowFunction: scala.Iterable * - Java WindowFunction: java.lang.Iterable */ -final class ScalaWindowFunctionWrapper[IN, OUT, KEY, W <: Window]( - private[this] val func: WindowFunction[IN, OUT, KEY, W]) - extends JWindowFunction[IN, OUT, KEY, W] with RichFunction { - - @throws(classOf[Exception]) - override def apply(key: KEY, window: W, input: java.lang.Iterable[IN], out: Collector[OUT]) { - func.apply(key, window, input.asScala, out) - } +final class ScalaWindowFunctionWrapper[IN, OUT, KEY, W <: Window] +(func: WindowFunction[IN, OUT, KEY, W]) + extends WrappingFunction[WindowFunction[IN, OUT, KEY, W]](func) + with JWindowFunction[IN, OUT, KEY, W] with RichFunction { @throws(classOf[Exception]) - override def open(parameters: Configuration) { - FunctionUtils.openFunction(func, parameters) - } - - @throws(classOf[Exception]) - override def close() { - FunctionUtils.closeFunction(func) - } - - override def setRuntimeContext(t: RuntimeContext) { - FunctionUtils.setFunctionRuntimeContext(func, t) + override def apply(key: KEY, window: W, input: java.lang.Iterable[IN], out: Collector[OUT]) { + wrappedFunction.apply(key, window, input.asScala, out) } - override def getRuntimeContext(): RuntimeContext = { + override def getRuntimeContext: RuntimeContext = { throw new RuntimeException("This should never be called") } - override def getIterationRuntimeContext(): IterationRuntimeContext = { + override def getIterationRuntimeContext: IterationRuntimeContext = { throw new RuntimeException("This should never be called") } }
