Repository: flink Updated Branches: refs/heads/release-1.5 5df2bc5c9 -> 0c06852b3
[FLINK-8659] Add migration itcases for broadcast state. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0c06852b Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0c06852b Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0c06852b Branch: refs/heads/release-1.5 Commit: 0c06852b3cecd414aac623ffd155ebc2a2a31336 Parents: 5df2bc5 Author: kkloudas <[email protected]> Authored: Thu May 3 10:05:13 2018 +0200 Committer: kkloudas <[email protected]> Committed: Fri May 18 15:10:54 2018 +0200 ---------------------------------------------------------------------- .../checkpointing/utils/MigrationTestUtils.java | 336 ++++++++++++++ .../utils/SavepointMigrationTestBase.java | 8 +- .../StatefulJobSavepointMigrationITCase.java | 308 +------------ ...atefulJobWBroadcastStateMigrationITCase.java | 391 ++++++++++++++++ .../_metadata | Bin 0 -> 20936 bytes .../_metadata | Bin 0 -> 20936 bytes .../_metadata | Bin 0 -> 220470 bytes .../_metadata | Bin 0 -> 220470 bytes ...tefulJobWBroadcastStateMigrationITCase.scala | 450 +++++++++++++++++++ 9 files changed, 1194 insertions(+), 299 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/MigrationTestUtils.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/MigrationTestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/MigrationTestUtils.java new file mode 100644 index 0000000..9314496 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/MigrationTestUtils.java @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.checkpointing.utils; + +import org.apache.flink.api.common.accumulators.IntCounter; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; +import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.api.watermark.Watermark; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertThat; + +/** + * A utility class containing common functions/classes used by multiple migration tests. + */ +public class MigrationTestUtils { + + /** + * A non-parallel source with list state used for testing. 
+ */ + public static class CheckpointingNonParallelSourceWithListState + implements SourceFunction<Tuple2<Long, Long>>, CheckpointedFunction { + + static final ListStateDescriptor<String> STATE_DESCRIPTOR = + new ListStateDescriptor<>("source-state", StringSerializer.INSTANCE); + + static final String CHECKPOINTED_STRING = "Here be dragons!"; + static final String CHECKPOINTED_STRING_1 = "Here be more dragons!"; + static final String CHECKPOINTED_STRING_2 = "Here be yet more dragons!"; + static final String CHECKPOINTED_STRING_3 = "Here be the mostest dragons!"; + + private static final long serialVersionUID = 1L; + + private volatile boolean isRunning = true; + + private final int numElements; + + private transient ListState<String> unionListState; + + CheckpointingNonParallelSourceWithListState(int numElements) { + this.numElements = numElements; + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + unionListState.clear(); + unionListState.add(CHECKPOINTED_STRING); + unionListState.add(CHECKPOINTED_STRING_1); + unionListState.add(CHECKPOINTED_STRING_2); + unionListState.add(CHECKPOINTED_STRING_3); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + unionListState = context.getOperatorStateStore().getListState( + STATE_DESCRIPTOR); + } + + @Override + public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception { + + ctx.emitWatermark(new Watermark(0)); + + synchronized (ctx.getCheckpointLock()) { + for (long i = 0; i < numElements; i++) { + ctx.collect(new Tuple2<>(i, i)); + } + } + + // don't emit a final watermark so that we don't trigger the registered event-time + // timers + while (isRunning) { + Thread.sleep(20); + } + } + + @Override + public void cancel() { + isRunning = false; + } + } + + /** + * A non-parallel source with union state used to verify the restored state of + * {@link CheckpointingNonParallelSourceWithListState}. 
+ */ + public static class CheckingNonParallelSourceWithListState + extends RichSourceFunction<Tuple2<Long, Long>> implements CheckpointedFunction { + + private static final long serialVersionUID = 1L; + + static final String SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR = CheckingNonParallelSourceWithListState.class + "_RESTORE_CHECK"; + + private volatile boolean isRunning = true; + + private final int numElements; + + CheckingNonParallelSourceWithListState(int numElements) { + this.numElements = numElements; + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + ListState<String> unionListState = context.getOperatorStateStore().getListState( + CheckpointingNonParallelSourceWithListState.STATE_DESCRIPTOR); + + if (context.isRestored()) { + assertThat(unionListState.get(), + containsInAnyOrder( + CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING, + CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_1, + CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_2, + CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_3)); + + getRuntimeContext().addAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, new IntCounter()); + getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); + } else { + throw new RuntimeException( + "This source should always be restored because it's only used when restoring from a savepoint."); + } + } + + @Override + public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception { + + // immediately trigger any set timers + ctx.emitWatermark(new Watermark(1000)); + + synchronized (ctx.getCheckpointLock()) { + for (long i = 0; i < numElements; i++) { + ctx.collect(new Tuple2<>(i, i)); + } + } + + while (isRunning) { + Thread.sleep(20); + } + } + + @Override + public void cancel() { + isRunning = false; + } + } + + /** + * A 
parallel source with union state used for testing. + */ + public static class CheckpointingParallelSourceWithUnionListState + extends RichSourceFunction<Tuple2<Long, Long>> implements CheckpointedFunction { + + static final ListStateDescriptor<String> STATE_DESCRIPTOR = + new ListStateDescriptor<>("source-state", StringSerializer.INSTANCE); + + static final String[] CHECKPOINTED_STRINGS = { + "Here be dragons!", + "Here be more dragons!", + "Here be yet more dragons!", + "Here be the mostest dragons!" }; + + private static final long serialVersionUID = 1L; + + private volatile boolean isRunning = true; + + private final int numElements; + + private transient ListState<String> unionListState; + + CheckpointingParallelSourceWithUnionListState(int numElements) { + this.numElements = numElements; + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + unionListState.clear(); + + for (String s : CHECKPOINTED_STRINGS) { + if (s.hashCode() % getRuntimeContext().getNumberOfParallelSubtasks() == getRuntimeContext().getIndexOfThisSubtask()) { + unionListState.add(s); + } + } + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + unionListState = context.getOperatorStateStore().getUnionListState( + STATE_DESCRIPTOR); + } + + @Override + public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception { + + ctx.emitWatermark(new Watermark(0)); + + synchronized (ctx.getCheckpointLock()) { + for (long i = 0; i < numElements; i++) { + if (i % getRuntimeContext().getNumberOfParallelSubtasks() == getRuntimeContext().getIndexOfThisSubtask()) { + ctx.collect(new Tuple2<>(i, i)); + } + } + } + + // don't emit a final watermark so that we don't trigger the registered event-time + // timers + while (isRunning) { + Thread.sleep(20); + } + } + + @Override + public void cancel() { + isRunning = false; + } + } + + /** + * A parallel source with union state used to verify the restored 
state of + * {@link CheckpointingParallelSourceWithUnionListState}. + */ + public static class CheckingParallelSourceWithUnionListState + extends RichParallelSourceFunction<Tuple2<Long, Long>> implements CheckpointedFunction { + + private static final long serialVersionUID = 1L; + + static final String SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR = CheckingParallelSourceWithUnionListState.class + "_RESTORE_CHECK"; + + private volatile boolean isRunning = true; + + private final int numElements; + + CheckingParallelSourceWithUnionListState(int numElements) { + this.numElements = numElements; + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + ListState<String> unionListState = context.getOperatorStateStore().getUnionListState( + CheckpointingNonParallelSourceWithListState.STATE_DESCRIPTOR); + + if (context.isRestored()) { + assertThat(unionListState.get(), + containsInAnyOrder(CheckpointingParallelSourceWithUnionListState.CHECKPOINTED_STRINGS)); + + getRuntimeContext().addAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, new IntCounter()); + getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); + } else { + throw new RuntimeException( + "This source should always be restored because it's only used when restoring from a savepoint."); + } + } + + @Override + public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception { + + // immediately trigger any set timers + ctx.emitWatermark(new Watermark(1000)); + + synchronized (ctx.getCheckpointLock()) { + for (long i = 0; i < numElements; i++) { + if (i % getRuntimeContext().getNumberOfParallelSubtasks() == getRuntimeContext().getIndexOfThisSubtask()) { + ctx.collect(new Tuple2<>(i, i)); + } + } + } + + while (isRunning) { + Thread.sleep(20); + } + } + + @Override + public void cancel() { + isRunning = false; + } + } + + /** + * A sink which 
counts the elements it sees in an accumulator. + */ + public static class AccumulatorCountingSink<T> extends RichSinkFunction<T> { + private static final long serialVersionUID = 1L; + + static final String NUM_ELEMENTS_ACCUMULATOR = AccumulatorCountingSink.class + "_NUM_ELEMENTS"; + + int count = 0; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + getRuntimeContext().addAccumulator(NUM_ELEMENTS_ACCUMULATOR, new IntCounter()); + } + + @Override + public void invoke(T value, Context context) throws Exception { + count++; + getRuntimeContext().getAccumulator(NUM_ELEMENTS_ACCUMULATOR).add(1); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java index cfa155b..84cb88a 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java @@ -144,7 +144,13 @@ public abstract class SavepointMigrationTestBase extends TestBaseUtils { boolean allDone = true; for (Tuple2<String, Integer> acc : expectedAccumulators) { - Integer numFinished = (Integer) accumulators.get(acc.f0).get(); + OptionalFailure<Object> accumOpt = accumulators.get(acc.f0); + if (accumOpt == null) { + allDone = false; + break; + } + + Integer numFinished = (Integer) accumOpt.get(); if (numFinished == null) { allDone = false; break; http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java 
---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java index d2de881..c74f304 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java @@ -21,26 +21,17 @@ package org.apache.flink.test.checkpointing.utils; import org.apache.flink.api.common.accumulators.IntCounter; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.restartstrategy.RestartStrategies; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeutils.base.LongSerializer; -import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; -import org.apache.flink.runtime.state.FunctionInitializationContext; -import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.runtime.state.StateBackendLoader; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.TimeCharacteristic; -import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; -import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; -import org.apache.flink.streaming.api.functions.source.RichSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimer; @@ -59,9 +50,7 @@ import org.junit.runners.Parameterized; import java.util.Arrays; import java.util.Collection; -import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThat; /** * Migration ITCases for a stateful job. The tests are parameterized to cover @@ -130,13 +119,13 @@ public class StatefulJobSavepointMigrationITCase extends SavepointMigrationTestB OneInputStreamOperator<Tuple2<Long, Long>, Tuple2<Long, Long>> timelyOperator; if (executionMode == ExecutionMode.PERFORM_SAVEPOINT) { - nonParallelSource = new CheckpointingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS); - parallelSource = new CheckpointingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS); + nonParallelSource = new MigrationTestUtils.CheckpointingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS); + parallelSource = new MigrationTestUtils.CheckpointingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS); flatMap = new CheckpointingKeyedStateFlatMap(); timelyOperator = new CheckpointingTimelyStatefulOperator(); } else if (executionMode == ExecutionMode.VERIFY_SAVEPOINT) { - nonParallelSource = new CheckingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS); - parallelSource = new CheckingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS); + nonParallelSource = new MigrationTestUtils.CheckingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS); + parallelSource = new MigrationTestUtils.CheckingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS); flatMap = new CheckingKeyedStateFlatMap(); timelyOperator = new CheckingTimelyStatefulOperator(); } else { 
@@ -152,7 +141,7 @@ public class StatefulJobSavepointMigrationITCase extends SavepointMigrationTestB "timely_stateful_operator", new TypeHint<Tuple2<Long, Long>>() {}.getTypeInfo(), timelyOperator).uid("CheckpointingTimelyStatefulOperator1") - .addSink(new AccumulatorCountingSink<>()); + .addSink(new MigrationTestUtils.AccumulatorCountingSink<>()); env .addSource(parallelSource).uid("CheckpointingSource2") @@ -163,24 +152,24 @@ public class StatefulJobSavepointMigrationITCase extends SavepointMigrationTestB "timely_stateful_operator", new TypeHint<Tuple2<Long, Long>>() {}.getTypeInfo(), timelyOperator).uid("CheckpointingTimelyStatefulOperator2") - .addSink(new AccumulatorCountingSink<>()); + .addSink(new MigrationTestUtils.AccumulatorCountingSink<>()); if (executionMode == ExecutionMode.PERFORM_SAVEPOINT) { executeAndSavepoint( env, "src/test/resources/" + getSavepointPath(testMigrateVersion, testStateBackend), - new Tuple2<>(AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2)); + new Tuple2<>(MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2)); } else { restoreAndExecute( env, getResourceFilename(getSavepointPath(testMigrateVersion, testStateBackend)), - new Tuple2<>(CheckingNonParallelSourceWithListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, 1), - new Tuple2<>(CheckingParallelSourceWithUnionListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, parallelism), + new Tuple2<>(MigrationTestUtils.CheckingNonParallelSourceWithListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, 1), + new Tuple2<>(MigrationTestUtils.CheckingParallelSourceWithUnionListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, parallelism), new Tuple2<>(CheckingKeyedStateFlatMap.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2), new Tuple2<>(CheckingTimelyStatefulOperator.SUCCESSFUL_PROCESS_CHECK_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2), new Tuple2<>(CheckingTimelyStatefulOperator.SUCCESSFUL_EVENT_TIME_CHECK_ACCUMULATOR, NUM_SOURCE_ELEMENTS 
* 2), new Tuple2<>(CheckingTimelyStatefulOperator.SUCCESSFUL_PROCESSING_TIME_CHECK_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2), - new Tuple2<>(AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2)); + new Tuple2<>(MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2)); } } @@ -195,261 +184,6 @@ public class StatefulJobSavepointMigrationITCase extends SavepointMigrationTestB } } - private static class CheckpointingNonParallelSourceWithListState - implements SourceFunction<Tuple2<Long, Long>>, CheckpointedFunction { - - static final ListStateDescriptor<String> STATE_DESCRIPTOR = - new ListStateDescriptor<>("source-state", StringSerializer.INSTANCE); - - static final String CHECKPOINTED_STRING = "Here be dragons!"; - static final String CHECKPOINTED_STRING_1 = "Here be more dragons!"; - static final String CHECKPOINTED_STRING_2 = "Here be yet more dragons!"; - static final String CHECKPOINTED_STRING_3 = "Here be the mostest dragons!"; - - private static final long serialVersionUID = 1L; - - private volatile boolean isRunning = true; - - private final int numElements; - - private transient ListState<String> unionListState; - - CheckpointingNonParallelSourceWithListState(int numElements) { - this.numElements = numElements; - } - - @Override - public void snapshotState(FunctionSnapshotContext context) throws Exception { - unionListState.clear(); - unionListState.add(CHECKPOINTED_STRING); - unionListState.add(CHECKPOINTED_STRING_1); - unionListState.add(CHECKPOINTED_STRING_2); - unionListState.add(CHECKPOINTED_STRING_3); - } - - @Override - public void initializeState(FunctionInitializationContext context) throws Exception { - unionListState = context.getOperatorStateStore().getListState( - STATE_DESCRIPTOR); - } - - @Override - public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception { - - ctx.emitWatermark(new Watermark(0)); - - synchronized (ctx.getCheckpointLock()) { - for (long i = 0; i < 
numElements; i++) { - ctx.collect(new Tuple2<>(i, i)); - } - } - - // don't emit a final watermark so that we don't trigger the registered event-time - // timers - while (isRunning) { - Thread.sleep(20); - } - } - - @Override - public void cancel() { - isRunning = false; - } - } - - private static class CheckingNonParallelSourceWithListState - extends RichSourceFunction<Tuple2<Long, Long>> implements CheckpointedFunction { - - private static final long serialVersionUID = 1L; - - static final String SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR = CheckingNonParallelSourceWithListState.class + "_RESTORE_CHECK"; - - private volatile boolean isRunning = true; - - private final int numElements; - - CheckingNonParallelSourceWithListState(int numElements) { - this.numElements = numElements; - } - - @Override - public void snapshotState(FunctionSnapshotContext context) throws Exception { - - } - - @Override - public void initializeState(FunctionInitializationContext context) throws Exception { - ListState<String> unionListState = context.getOperatorStateStore().getListState( - CheckpointingNonParallelSourceWithListState.STATE_DESCRIPTOR); - - if (context.isRestored()) { - assertThat(unionListState.get(), - containsInAnyOrder( - CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING, - CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_1, - CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_2, - CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_3)); - - getRuntimeContext().addAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, new IntCounter()); - getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); - } else { - throw new RuntimeException( - "This source should always be restored because it's only used when restoring from a savepoint."); - } - } - - @Override - public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception { - - // immediately trigger any set timers - ctx.emitWatermark(new 
Watermark(1000)); - - synchronized (ctx.getCheckpointLock()) { - for (long i = 0; i < numElements; i++) { - ctx.collect(new Tuple2<>(i, i)); - } - } - - while (isRunning) { - Thread.sleep(20); - } - } - - @Override - public void cancel() { - isRunning = false; - } - } - - private static class CheckpointingParallelSourceWithUnionListState - extends RichSourceFunction<Tuple2<Long, Long>> implements CheckpointedFunction { - - static final ListStateDescriptor<String> STATE_DESCRIPTOR = - new ListStateDescriptor<>("source-state", StringSerializer.INSTANCE); - - static final String[] CHECKPOINTED_STRINGS = { - "Here be dragons!", - "Here be more dragons!", - "Here be yet more dragons!", - "Here be the mostest dragons!" }; - - private static final long serialVersionUID = 1L; - - private volatile boolean isRunning = true; - - private final int numElements; - - private transient ListState<String> unionListState; - - CheckpointingParallelSourceWithUnionListState(int numElements) { - this.numElements = numElements; - } - - @Override - public void snapshotState(FunctionSnapshotContext context) throws Exception { - unionListState.clear(); - - for (String s : CHECKPOINTED_STRINGS) { - if (s.hashCode() % getRuntimeContext().getNumberOfParallelSubtasks() == getRuntimeContext().getIndexOfThisSubtask()) { - unionListState.add(s); - } - } - } - - @Override - public void initializeState(FunctionInitializationContext context) throws Exception { - unionListState = context.getOperatorStateStore().getUnionListState( - STATE_DESCRIPTOR); - } - - @Override - public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception { - - ctx.emitWatermark(new Watermark(0)); - - synchronized (ctx.getCheckpointLock()) { - for (long i = 0; i < numElements; i++) { - if (i % getRuntimeContext().getNumberOfParallelSubtasks() == getRuntimeContext().getIndexOfThisSubtask()) { - ctx.collect(new Tuple2<>(i, i)); - } - } - } - - // don't emit a final watermark so that we don't trigger the registered 
event-time - // timers - while (isRunning) { - Thread.sleep(20); - } - } - - @Override - public void cancel() { - isRunning = false; - } - } - - private static class CheckingParallelSourceWithUnionListState - extends RichParallelSourceFunction<Tuple2<Long, Long>> implements CheckpointedFunction { - - private static final long serialVersionUID = 1L; - - static final String SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR = CheckingParallelSourceWithUnionListState.class + "_RESTORE_CHECK"; - - private volatile boolean isRunning = true; - - private final int numElements; - - CheckingParallelSourceWithUnionListState(int numElements) { - this.numElements = numElements; - } - - @Override - public void snapshotState(FunctionSnapshotContext context) throws Exception { - - } - - @Override - public void initializeState(FunctionInitializationContext context) throws Exception { - ListState<String> unionListState = context.getOperatorStateStore().getUnionListState( - CheckpointingNonParallelSourceWithListState.STATE_DESCRIPTOR); - - if (context.isRestored()) { - assertThat(unionListState.get(), - containsInAnyOrder(CheckpointingParallelSourceWithUnionListState.CHECKPOINTED_STRINGS)); - - getRuntimeContext().addAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, new IntCounter()); - getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); - } else { - throw new RuntimeException( - "This source should always be restored because it's only used when restoring from a savepoint."); - } - } - - @Override - public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception { - - // immediately trigger any set timers - ctx.emitWatermark(new Watermark(1000)); - - synchronized (ctx.getCheckpointLock()) { - for (long i = 0; i < numElements; i++) { - if (i % getRuntimeContext().getNumberOfParallelSubtasks() == getRuntimeContext().getIndexOfThisSubtask()) { - ctx.collect(new Tuple2<>(i, i)); - } - } - } - - while (isRunning) { - Thread.sleep(20); - } - } - - @Override - public 
void cancel() { - isRunning = false; - } - } - private static class CheckpointingKeyedStateFlatMap extends RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> { private static final long serialVersionUID = 1L; @@ -609,26 +343,4 @@ public class StatefulJobSavepointMigrationITCase extends SavepointMigrationTestB getRuntimeContext().getAccumulator(SUCCESSFUL_PROCESSING_TIME_CHECK_ACCUMULATOR).add(1); } } - - private static class AccumulatorCountingSink<T> extends RichSinkFunction<T> { - private static final long serialVersionUID = 1L; - - static final String NUM_ELEMENTS_ACCUMULATOR = AccumulatorCountingSink.class + "_NUM_ELEMENTS"; - - int count = 0; - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - getRuntimeContext().addAccumulator(NUM_ELEMENTS_ACCUMULATOR, new IntCounter()); - } - - @Override - public void invoke(T value, Context context) throws Exception { - count++; - getRuntimeContext().getAccumulator(NUM_ELEMENTS_ACCUMULATOR).add(1); - } - } - } http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobWBroadcastStateMigrationITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobWBroadcastStateMigrationITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobWBroadcastStateMigrationITCase.java new file mode 100644 index 0000000..147415d --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobWBroadcastStateMigrationITCase.java @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.checkpointing.utils; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; +import org.apache.flink.runtime.state.StateBackendLoader; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.streaming.api.TimeCharacteristic; +import org.apache.flink.streaming.api.datastream.BroadcastStream; +import org.apache.flink.streaming.api.datastream.KeyedStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.util.migration.MigrationVersion; +import org.apache.flink.util.Collector; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +/** + * Migration ITCases for a 
stateful job with broadcast state. The tests are parameterized to (potentially) + * cover migrating for multiple previous Flink versions, as well as for different state backends. + */ +@RunWith(Parameterized.class) +public class StatefulJobWBroadcastStateMigrationITCase extends SavepointMigrationTestBase { + + private static final int NUM_SOURCE_ELEMENTS = 4; + + // TODO change this to PERFORM_SAVEPOINT to regenerate binary savepoints + private final StatefulJobSavepointMigrationITCase.ExecutionMode executionMode = + StatefulJobSavepointMigrationITCase.ExecutionMode.VERIFY_SAVEPOINT; + + @Parameterized.Parameters(name = "Migrate Savepoint / Backend: {0}") + public static Collection<Tuple2<MigrationVersion, String>> parameters () { + return Arrays.asList( + Tuple2.of(MigrationVersion.v1_5, StateBackendLoader.MEMORY_STATE_BACKEND_NAME), + Tuple2.of(MigrationVersion.v1_5, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME)); + } + + private final MigrationVersion testMigrateVersion; + private final String testStateBackend; + + public StatefulJobWBroadcastStateMigrationITCase(Tuple2<MigrationVersion, String> testMigrateVersionAndBackend) throws Exception { + this.testMigrateVersion = testMigrateVersionAndBackend.f0; + this.testStateBackend = testMigrateVersionAndBackend.f1; + } + + @Test + public void testSavepoint() throws Exception { + + final int parallelism = 4; + + final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setRestartStrategy(RestartStrategies.noRestart()); + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); + + switch (testStateBackend) { + case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME: + env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend())); + break; + case StateBackendLoader.MEMORY_STATE_BACKEND_NAME: + env.setStateBackend(new MemoryStateBackend()); + break; + default: + throw new UnsupportedOperationException(); + } + + env.enableCheckpointing(500); + 
env.setParallelism(parallelism); + env.setMaxParallelism(parallelism); + + SourceFunction<Tuple2<Long, Long>> nonParallelSource; + SourceFunction<Tuple2<Long, Long>> nonParallelSourceB; + SourceFunction<Tuple2<Long, Long>> parallelSource; + SourceFunction<Tuple2<Long, Long>> parallelSourceB; + KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> firstBroadcastFunction; + KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> secondBroadcastFunction; + + final Map<Long, Long> expectedFirstState = new HashMap<>(); + expectedFirstState.put(0L, 0L); + expectedFirstState.put(1L, 1L); + expectedFirstState.put(2L, 2L); + expectedFirstState.put(3L, 3L); + + final Map<String, String> expectedSecondState = new HashMap<>(); + expectedSecondState.put("0", "0"); + expectedSecondState.put("1", "1"); + expectedSecondState.put("2", "2"); + expectedSecondState.put("3", "3"); + + final Map<String, String> expectedThirdState = new HashMap<>(); + expectedThirdState.put("0", "0"); + expectedThirdState.put("1", "1"); + expectedThirdState.put("2", "2"); + expectedThirdState.put("3", "3"); + + if (executionMode == StatefulJobSavepointMigrationITCase.ExecutionMode.PERFORM_SAVEPOINT) { + nonParallelSource = new MigrationTestUtils.CheckpointingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS); + nonParallelSourceB = new MigrationTestUtils.CheckpointingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS); + parallelSource = new MigrationTestUtils.CheckpointingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS); + parallelSourceB = new MigrationTestUtils.CheckpointingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS); + firstBroadcastFunction = new CheckpointingKeyedBroadcastFunction(); + secondBroadcastFunction = new CheckpointingKeyedSingleBroadcastFunction(); + } else if (executionMode == StatefulJobSavepointMigrationITCase.ExecutionMode.VERIFY_SAVEPOINT) { + nonParallelSource = new 
MigrationTestUtils.CheckingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS); + nonParallelSourceB = new MigrationTestUtils.CheckingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS); + parallelSource = new MigrationTestUtils.CheckingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS); + parallelSourceB = new MigrationTestUtils.CheckingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS); + firstBroadcastFunction = new CheckingKeyedBroadcastFunction(expectedFirstState, expectedSecondState); + secondBroadcastFunction = new CheckingKeyedSingleBroadcastFunction(expectedThirdState); + } else { + throw new IllegalStateException("Unknown ExecutionMode " + executionMode); + } + + KeyedStream<Tuple2<Long, Long>, Long> npStream = env + .addSource(nonParallelSource).uid("CheckpointingSource1") + .keyBy(new KeySelector<Tuple2<Long, Long>, Long>() { + + private static final long serialVersionUID = -4514793867774977152L; + + @Override + public Long getKey(Tuple2<Long, Long> value) throws Exception { + return value.f0; + } + }); + + KeyedStream<Tuple2<Long, Long>, Long> pStream = env + .addSource(parallelSource).uid("CheckpointingSource2") + .keyBy(new KeySelector<Tuple2<Long, Long>, Long>() { + + private static final long serialVersionUID = 4940496713319948104L; + + @Override + public Long getKey(Tuple2<Long, Long> value) throws Exception { + return value.f0; + } + }); + + final MapStateDescriptor<Long, Long> firstBroadcastStateDesc = new MapStateDescriptor<>( + "broadcast-state-1", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO + ); + + final MapStateDescriptor<String, String> secondBroadcastStateDesc = new MapStateDescriptor<>( + "broadcast-state-2", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + + final MapStateDescriptor<String, String> thirdBroadcastStateDesc = new MapStateDescriptor<>( + "broadcast-state-3", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + + BroadcastStream<Tuple2<Long, Long>> npBroadcastStream 
= env + .addSource(nonParallelSourceB).uid("BrCheckpointingSource1") + .broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc); + + BroadcastStream<Tuple2<Long, Long>> pBroadcastStream = env + .addSource(parallelSourceB).uid("BrCheckpointingSource2") + .broadcast(thirdBroadcastStateDesc); + + npStream + .connect(npBroadcastStream) + .process(firstBroadcastFunction).uid("BrProcess1") + .addSink(new MigrationTestUtils.AccumulatorCountingSink<>()); + + pStream + .connect(pBroadcastStream) + .process(secondBroadcastFunction).uid("BrProcess2") + .addSink(new MigrationTestUtils.AccumulatorCountingSink<>()); + + if (executionMode == StatefulJobSavepointMigrationITCase.ExecutionMode.PERFORM_SAVEPOINT) { + executeAndSavepoint( + env, + "src/test/resources/" + getBroadcastSavepointPath(testMigrateVersion, testStateBackend), + new Tuple2<>(MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, 2 * NUM_SOURCE_ELEMENTS)); + } else { + restoreAndExecute( + env, + getResourceFilename(getBroadcastSavepointPath(testMigrateVersion, testStateBackend)), + new Tuple2<>(MigrationTestUtils.CheckingNonParallelSourceWithListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, 2), // we have 2 sources + new Tuple2<>(MigrationTestUtils.CheckingParallelSourceWithUnionListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, 2 * parallelism), // we have 2 sources + new Tuple2<>(MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2) + ); + } + } + + private String getBroadcastSavepointPath(MigrationVersion savepointVersion, String backendType) { + switch (backendType) { + case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME: + return "new-stateful-broadcast-udf-migration-itcase-flink" + savepointVersion + "-rocksdb-savepoint"; + case StateBackendLoader.MEMORY_STATE_BACKEND_NAME: + return "new-stateful-broadcast-udf-migration-itcase-flink" + savepointVersion + "-savepoint"; + default: + throw new UnsupportedOperationException(); + } + } + + /** + * A simple 
{@link KeyedBroadcastProcessFunction} that puts everything on the broadcast side in the state. + */ + private static class CheckpointingKeyedBroadcastFunction + extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> { + + private static final long serialVersionUID = 1333992081671604521L; + + private MapStateDescriptor<Long, Long> firstStateDesc; + + private MapStateDescriptor<String, String> secondStateDesc; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + firstStateDesc = new MapStateDescriptor<>( + "broadcast-state-1", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO + ); + + secondStateDesc = new MapStateDescriptor<>( + "broadcast-state-2", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + } + + @Override + public void processElement(Tuple2<Long, Long> value, ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception { + out.collect(value); + } + + @Override + public void processBroadcastElement(Tuple2<Long, Long> value, Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception { + ctx.getBroadcastState(firstStateDesc).put(value.f0, value.f1); + ctx.getBroadcastState(secondStateDesc).put(Long.toString(value.f0), Long.toString(value.f1)); + } + } + + /** + * A simple {@link KeyedBroadcastProcessFunction} that puts everything on the broadcast side in the state. 
+ */ + private static class CheckpointingKeyedSingleBroadcastFunction + extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> { + + private static final long serialVersionUID = 1333992081671604521L; + + private MapStateDescriptor<String, String> stateDesc; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + stateDesc = new MapStateDescriptor<>( + "broadcast-state-3", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + } + + @Override + public void processElement(Tuple2<Long, Long> value, ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception { + out.collect(value); + } + + @Override + public void processBroadcastElement(Tuple2<Long, Long> value, Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception { + ctx.getBroadcastState(stateDesc).put(Long.toString(value.f0), Long.toString(value.f1)); + } + } + + /** + * A simple {@link KeyedBroadcastProcessFunction} that verifies the contents of the broadcast state after recovery. 
+ */ + private static class CheckingKeyedBroadcastFunction + extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> { + + private static final long serialVersionUID = 1333992081671604521L; + + private final Map<Long, Long> expectedFirstState; + + private final Map<String, String> expectedSecondState; + + private MapStateDescriptor<Long, Long> firstStateDesc; + + private MapStateDescriptor<String, String> secondStateDesc; + + CheckingKeyedBroadcastFunction(Map<Long, Long> firstState, Map<String, String> secondState) { + this.expectedFirstState = firstState; + this.expectedSecondState = secondState; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + firstStateDesc = new MapStateDescriptor<>( + "broadcast-state-1", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO + ); + + secondStateDesc = new MapStateDescriptor<>( + "broadcast-state-2", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + } + + @Override + public void processElement(Tuple2<Long, Long> value, ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception { + + final Map<Long, Long> actualFirstState = new HashMap<>(); + for (Map.Entry<Long, Long> entry: ctx.getBroadcastState(firstStateDesc).immutableEntries()) { + actualFirstState.put(entry.getKey(), entry.getValue()); + } + Assert.assertEquals(expectedFirstState, actualFirstState); + + final Map<String, String> actualSecondState = new HashMap<>(); + for (Map.Entry<String, String> entry: ctx.getBroadcastState(secondStateDesc).immutableEntries()) { + actualSecondState.put(entry.getKey(), entry.getValue()); + } + Assert.assertEquals(expectedSecondState, actualSecondState); + + out.collect(value); + } + + @Override + public void processBroadcastElement(Tuple2<Long, Long> value, Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception { + // now we do nothing as we just want to verify the contents of the 
broadcast state. + } + } + + /** + * A simple {@link KeyedBroadcastProcessFunction} that verifies the contents of the broadcast state after recovery. + */ + private static class CheckingKeyedSingleBroadcastFunction + extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> { + + private static final long serialVersionUID = 1333992081671604521L; + + private final Map<String, String> expectedState; + + private MapStateDescriptor<String, String> stateDesc; + + CheckingKeyedSingleBroadcastFunction(Map<String, String> state) { + this.expectedState = state; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + stateDesc = new MapStateDescriptor<>( + "broadcast-state-3", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + } + + @Override + public void processElement(Tuple2<Long, Long> value, ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception { + final Map<String, String> actualState = new HashMap<>(); + for (Map.Entry<String, String> entry: ctx.getBroadcastState(stateDesc).immutableEntries()) { + actualState.put(entry.getKey(), entry.getValue()); + } + Assert.assertEquals(expectedState, actualState); + + out.collect(value); + } + + @Override + public void processBroadcastElement(Tuple2<Long, Long> value, Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception { + // now we do nothing as we just want to verify the contents of the broadcast state. 
+ } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata new file mode 100644 index 0000000..aed6183 Binary files /dev/null and b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata differ http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-savepoint/_metadata ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-savepoint/_metadata b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-savepoint/_metadata new file mode 100644 index 0000000..13e8d32 Binary files /dev/null and b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-savepoint/_metadata differ http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-jobmanager-savepoint/_metadata ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-jobmanager-savepoint/_metadata b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-jobmanager-savepoint/_metadata new file mode 100644 index 0000000..f845ee3 Binary files /dev/null and 
b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-jobmanager-savepoint/_metadata differ http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata new file mode 100644 index 0000000..6b05ef9 Binary files /dev/null and b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata differ http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/scala/org/apache/flink/api/scala/migration/StatefulJobWBroadcastStateMigrationITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/migration/StatefulJobWBroadcastStateMigrationITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/migration/StatefulJobWBroadcastStateMigrationITCase.scala new file mode 100644 index 0000000..0ed705b --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/migration/StatefulJobWBroadcastStateMigrationITCase.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala.migration + +import java.util + +import org.apache.flink.api.common.accumulators.IntCounter +import org.apache.flink.api.common.functions.RichFlatMapFunction +import org.apache.flink.api.common.state._ +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.functions.KeySelector +import org.apache.flink.api.java.tuple.Tuple2 +import org.apache.flink.api.scala.createTypeInformation +import org.apache.flink.api.scala.migration.CustomEnum.CustomEnum +import org.apache.flink.configuration.Configuration +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend +import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext, StateBackendLoader} +import org.apache.flink.runtime.state.memory.MemoryStateBackend +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction +import org.apache.flink.streaming.api.functions.sink.RichSinkFunction +import org.apache.flink.streaming.api.functions.source.SourceFunction +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.streaming.util.migration.MigrationVersion +import org.apache.flink.test.checkpointing.utils.SavepointMigrationTestBase +import org.apache.flink.util.Collector +import org.junit.runner.RunWith +import 
org.junit.runners.Parameterized +import org.junit.{Assert, Ignore, Test} + +import scala.util.{Failure, Properties, Try} + +object StatefulJobWBroadcastStateMigrationITCase { + + @Parameterized.Parameters(name = "Migrate Savepoint / Backend: {0}") + def parameters: util.Collection[(MigrationVersion, String)] = { + util.Arrays.asList( + (MigrationVersion.v1_5, StateBackendLoader.MEMORY_STATE_BACKEND_NAME), + (MigrationVersion.v1_5, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME)) + } + + // TODO to generate savepoints for a specific Flink version / backend type, + // TODO change these values accordingly, e.g. to generate for 1.3 with RocksDB, + // TODO set as (MigrationVersion.v1_3, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME) + val GENERATE_SAVEPOINT_VER: MigrationVersion = MigrationVersion.v1_5 + val GENERATE_SAVEPOINT_BACKEND_TYPE: String = StateBackendLoader.MEMORY_STATE_BACKEND_NAME + + val SCALA_VERSION: String = { + val versionString = Properties.versionString.split(" ")(1) + versionString.substring(0, versionString.lastIndexOf(".")) + } + + val NUM_ELEMENTS = 4 +} + +/** + * ITCase for migrating Scala state types across different Flink versions. 
+ */ +@RunWith(classOf[Parameterized]) +class StatefulJobWBroadcastStateMigrationITCase( + migrationVersionAndBackend: (MigrationVersion, String)) + extends SavepointMigrationTestBase with Serializable { + + @Test + @Ignore + def testCreateSavepointWithBroadcastState(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + + StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_BACKEND_TYPE match { + case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME => + env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend())) + case StateBackendLoader.MEMORY_STATE_BACKEND_NAME => + env.setStateBackend(new MemoryStateBackend()) + case _ => throw new UnsupportedOperationException + } + + lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long]( + "broadcast-state-1", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]) + + lazy val secondBroadcastStateDesc = new MapStateDescriptor[String, String]( + "broadcast-state-2", + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO) + + env.setStateBackend(new MemoryStateBackend) + env.enableCheckpointing(500) + env.setParallelism(4) + env.setMaxParallelism(4) + + val stream = env + .addSource( + new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedSource") + .keyBy( + new KeySelector[(Long, Long), Long] { + override def getKey(value: (Long, Long)): Long = value._1 + } + ) + .flatMap(new StatefulFlatMapper) + .keyBy( + new KeySelector[(Long, Long), Long] { + override def getKey(value: (Long, Long)): Long = value._1 + } + ) + + val broadcastStream = env + .addSource( + new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedBroadcastSource") + .broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc) + + stream + .connect(broadcastStream) + .process(new TestBroadcastProcessFunction) + .addSink(new 
AccumulatorCountingSink) + + executeAndSavepoint( + env, + s"src/test/resources/stateful-scala-with-broadcast" + + s"${StatefulJobWBroadcastStateMigrationITCase.SCALA_VERSION}" + + s"-udf-migration-itcase-flink" + + s"${StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_VER}" + + s"-${StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_BACKEND_TYPE}-savepoint", + new Tuple2( + AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, + StatefulJobWBroadcastStateMigrationITCase.NUM_ELEMENTS + ) + ) + } + + @Test + def testRestoreSavepointWithBroadcast(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + + migrationVersionAndBackend._2 match { + case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME => + env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend())) + case StateBackendLoader.MEMORY_STATE_BACKEND_NAME => + env.setStateBackend(new MemoryStateBackend()) + case _ => throw new UnsupportedOperationException + } + + lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long]( + "broadcast-state-1", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]) + + lazy val secondBroadcastStateDesc = new MapStateDescriptor[String, String]( + "broadcast-state-2", + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO) + + env.setStateBackend(new MemoryStateBackend) + env.enableCheckpointing(500) + env.setParallelism(4) + env.setMaxParallelism(4) + + val stream = env + .addSource( + new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedSource") + .keyBy( + new KeySelector[(Long, Long), Long] { + override def getKey(value: (Long, Long)): Long = value._1 + } + ) + .flatMap(new StatefulFlatMapper) + .keyBy( + new KeySelector[(Long, Long), Long] { + override def getKey(value: (Long, Long)): Long = value._1 + } + ) + + val broadcastStream = env + .addSource( 
+ new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedBroadcastSource") + .broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc) + + val expectedFirstState: Map[Long, Long] = + Map(0L -> 0L, 1L -> 1L, 2L -> 2L, 3L -> 3L) + val expectedSecondState: Map[String, String] = + Map("0" -> "0", "1" -> "1", "2" -> "2", "3" -> "3") + + stream + .connect(broadcastStream) + .process(new VerifyingBroadcastProcessFunction(expectedFirstState, expectedSecondState)) + .addSink(new AccumulatorCountingSink) + + restoreAndExecute( + env, + SavepointMigrationTestBase.getResourceFilename( + s"stateful-scala-with-broadcast${StatefulJobWBroadcastStateMigrationITCase.SCALA_VERSION}" + + s"-udf-migration-itcase-flink${migrationVersionAndBackend._1}" + + s"-${migrationVersionAndBackend._2}-savepoint"), + new Tuple2( + AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, + StatefulJobWBroadcastStateMigrationITCase.NUM_ELEMENTS) + ) + } +} + +class TestBroadcastProcessFunction + extends KeyedBroadcastProcessFunction + [Long, (Long, Long), (Long, Long), (Long, Long)] { + + lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long]( + "broadcast-state-1", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]) + + val secondBroadcastStateDesc = new MapStateDescriptor[String, String]( + "broadcast-state-2", + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO) + + @throws[Exception] + override def processElement( + value: (Long, Long), + ctx: KeyedBroadcastProcessFunction + [Long, (Long, Long), (Long, Long), (Long, Long)]#ReadOnlyContext, + out: Collector[(Long, Long)]): Unit = { + + out.collect(value) + } + + @throws[Exception] + override def processBroadcastElement( + value: (Long, Long), + ctx: KeyedBroadcastProcessFunction + [Long, (Long, Long), (Long, Long), (Long, Long)]#Context, + out: Collector[(Long, Long)]): Unit = { + + 
ctx.getBroadcastState(firstBroadcastStateDesc).put(value._1, value._2) + ctx.getBroadcastState(secondBroadcastStateDesc).put(value._1.toString, value._2.toString) + } +} + +@SerialVersionUID(1L) +private object CheckpointedSource { + var CHECKPOINTED_STRING = "Here be dragons!" +} + +@SerialVersionUID(1L) +private class CheckpointedSource(val numElements: Int) + extends SourceFunction[(Long, Long)] with CheckpointedFunction { + + private var isRunning = true + private var state: ListState[CustomCaseClass] = _ + + @throws[Exception] + override def run(ctx: SourceFunction.SourceContext[(Long, Long)]) { + ctx.emitWatermark(new Watermark(0)) + ctx.getCheckpointLock synchronized { + var i = 0 + while (i < numElements) { + ctx.collect(i, i) + i += 1 + } + } + // don't emit a final watermark so that we don't trigger the registered event-time + // timers + while (isRunning) Thread.sleep(20) + } + + def cancel() { + isRunning = false + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + state = context.getOperatorStateStore.getOperatorState( + new ListStateDescriptor[CustomCaseClass]( + "sourceState", createTypeInformation[CustomCaseClass])) + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + state.clear() + state.add(CustomCaseClass("Here be dragons!", 123)) + } +} + +@SerialVersionUID(1L) +private object AccumulatorCountingSink { + var NUM_ELEMENTS_ACCUMULATOR = classOf[AccumulatorCountingSink[_]] + "_NUM_ELEMENTS" +} + +@SerialVersionUID(1L) +private class AccumulatorCountingSink[T] extends RichSinkFunction[T] { + + private var count: Int = 0 + + @throws[Exception] + override def open(parameters: Configuration) { + super.open(parameters) + getRuntimeContext.addAccumulator( + AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, new IntCounter) + } + + @throws[Exception] + def invoke(value: T) { + count += 1 + getRuntimeContext.getAccumulator( + AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR).add(1) + } +} + 
/**
 * Populates the keyed state of the savepoint-generating job: on every element it
 * overwrites a set of value states covering the Scala-specific serializers (case
 * classes, nested case classes, collections, Try, Option, Either, enumerations) and
 * forwards the element unchanged.
 */
class StatefulFlatMapper extends RichFlatMapFunction[(Long, Long), (Long, Long)] {

  private var caseClassState: ValueState[CustomCaseClass] = _
  private var caseClassWithNestingState: ValueState[CustomCaseClassWithNesting] = _
  private var collectionState: ValueState[List[CustomCaseClass]] = _
  private var tryState: ValueState[Try[CustomCaseClass]] = _
  private var tryFailureState: ValueState[Try[CustomCaseClass]] = _
  private var optionState: ValueState[Option[CustomCaseClass]] = _
  private var optionNoneState: ValueState[Option[CustomCaseClass]] = _
  private var eitherLeftState: ValueState[Either[CustomCaseClass, String]] = _
  private var eitherRightState: ValueState[Either[CustomCaseClass, String]] = _
  private var enumOneState: ValueState[CustomEnum] = _
  private var enumThreeState: ValueState[CustomEnum] = _

  override def open(parameters: Configuration): Unit = {
    caseClassState = getRuntimeContext.getState(
      new ValueStateDescriptor[CustomCaseClass](
        "caseClassState", createTypeInformation[CustomCaseClass]))
    caseClassWithNestingState = getRuntimeContext.getState(
      new ValueStateDescriptor[CustomCaseClassWithNesting](
        "caseClassWithNestingState", createTypeInformation[CustomCaseClassWithNesting]))
    collectionState = getRuntimeContext.getState(
      new ValueStateDescriptor[List[CustomCaseClass]](
        "collectionState", createTypeInformation[List[CustomCaseClass]]))
    tryState = getRuntimeContext.getState(
      new ValueStateDescriptor[Try[CustomCaseClass]](
        "tryState", createTypeInformation[Try[CustomCaseClass]]))
    tryFailureState = getRuntimeContext.getState(
      new ValueStateDescriptor[Try[CustomCaseClass]](
        "tryFailureState", createTypeInformation[Try[CustomCaseClass]]))
    optionState = getRuntimeContext.getState(
      new ValueStateDescriptor[Option[CustomCaseClass]](
        "optionState", createTypeInformation[Option[CustomCaseClass]]))
    optionNoneState = getRuntimeContext.getState(
      new ValueStateDescriptor[Option[CustomCaseClass]](
        "optionNoneState", createTypeInformation[Option[CustomCaseClass]]))
    eitherLeftState = getRuntimeContext.getState(
      new ValueStateDescriptor[Either[CustomCaseClass, String]](
        "eitherLeftState", createTypeInformation[Either[CustomCaseClass, String]]))
    eitherRightState = getRuntimeContext.getState(
      new ValueStateDescriptor[Either[CustomCaseClass, String]](
        "eitherRightState", createTypeInformation[Either[CustomCaseClass, String]]))
    enumOneState = getRuntimeContext.getState(
      new ValueStateDescriptor[CustomEnum](
        "enumOneState", createTypeInformation[CustomEnum]))
    enumThreeState = getRuntimeContext.getState(
      new ValueStateDescriptor[CustomEnum](
        "enumThreeState", createTypeInformation[CustomEnum]))
  }

  override def flatMap(in: (Long, Long), collector: Collector[(Long, Long)]): Unit = {
    caseClassState.update(CustomCaseClass(in._1.toString, in._2 * 2))
    caseClassWithNestingState.update(
      CustomCaseClassWithNesting(in._1, CustomCaseClass(in._1.toString, in._2 * 2)))
    collectionState.update(List(CustomCaseClass(in._1.toString, in._2 * 2)))
    tryState.update(Try(CustomCaseClass(in._1.toString, in._2 * 5)))
    tryFailureState.update(Failure(new RuntimeException))
    optionState.update(Some(CustomCaseClass(in._1.toString, in._2 * 2)))
    optionNoneState.update(None)
    eitherLeftState.update(Left(CustomCaseClass(in._1.toString, in._2 * 2)))
    eitherRightState.update(Right((in._1 * 3).toString))
    enumOneState.update(CustomEnum.ONE)
    // Fixed: this used to write CustomEnum.THREE into enumOneState a second time,
    // leaving enumThreeState (registered in open()) forever empty.
    // NOTE(review): any checking mapper that verifies restored state must expect the
    // same values written here — confirm it is updated consistently.
    enumThreeState.update(CustomEnum.THREE)

    collector.collect(in)
  }
}

/**
 * The [[KeyedBroadcastProcessFunction]] used when *restoring* the savepoint: on every
 * keyed element it asserts that both restored broadcast states exactly equal the
 * expected maps, then forwards the element; broadcast elements are ignored.
 */
class VerifyingBroadcastProcessFunction(
    firstExpectedBroadcastState: Map[Long, Long],
    secondExpectedBroadcastState: Map[String, String])
  extends KeyedBroadcastProcessFunction
    [Long, (Long, Long), (Long, Long), (Long, Long)] {

  // Descriptors must match (name and serializers) the ones the savepoint was written
  // with, otherwise the broadcast state cannot be found on restore.
  lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
    "broadcast-state-1",
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])

  val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
    "broadcast-state-2",
    BasicTypeInfo.STRING_TYPE_INFO,
    BasicTypeInfo.STRING_TYPE_INFO)

  @throws[Exception]
  override def processElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction
        [Long, (Long, Long), (Long, Long), (Long, Long)]#ReadOnlyContext,
      out: Collector[(Long, Long)]): Unit = {

    // Single import for both loops (this was previously duplicated in this method).
    import scala.collection.JavaConversions._

    var actualFirstState = Map[Long, Long]()

    for (entry <- ctx.getBroadcastState(firstBroadcastStateDesc).immutableEntries()) {
      val v = firstExpectedBroadcastState.get(entry.getKey).get
      Assert.assertEquals(v, entry.getValue)
      actualFirstState += (entry.getKey -> entry.getValue)
    }

    Assert.assertEquals(firstExpectedBroadcastState, actualFirstState)

    var actualSecondState = Map[String, String]()

    for (entry <- ctx.getBroadcastState(secondBroadcastStateDesc).immutableEntries()) {
      val v = secondExpectedBroadcastState.get(entry.getKey).get
      Assert.assertEquals(v, entry.getValue)
      actualSecondState += (entry.getKey -> entry.getValue)
    }

    Assert.assertEquals(secondExpectedBroadcastState, actualSecondState)
    out.collect(value)
  }

  @throws[Exception]
  override def processBroadcastElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction
        [Long, (Long, Long), (Long, Long), (Long, Long)]#Context,
      out: Collector[(Long, Long)]): Unit = {

    // do nothing
  }
}
