Repository: flink
Updated Branches:
  refs/heads/release-1.5 5df2bc5c9 -> 0c06852b3


[FLINK-8659] Add migration ITCases for broadcast state.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0c06852b
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0c06852b
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0c06852b

Branch: refs/heads/release-1.5
Commit: 0c06852b3cecd414aac623ffd155ebc2a2a31336
Parents: 5df2bc5
Author: kkloudas <[email protected]>
Authored: Thu May 3 10:05:13 2018 +0200
Committer: kkloudas <[email protected]>
Committed: Fri May 18 15:10:54 2018 +0200

----------------------------------------------------------------------
 .../checkpointing/utils/MigrationTestUtils.java | 336 ++++++++++++++
 .../utils/SavepointMigrationTestBase.java       |   8 +-
 .../StatefulJobSavepointMigrationITCase.java    | 308 +------------
 ...atefulJobWBroadcastStateMigrationITCase.java | 391 ++++++++++++++++
 .../_metadata                                   | Bin 0 -> 20936 bytes
 .../_metadata                                   | Bin 0 -> 20936 bytes
 .../_metadata                                   | Bin 0 -> 220470 bytes
 .../_metadata                                   | Bin 0 -> 220470 bytes
 ...tefulJobWBroadcastStateMigrationITCase.scala | 450 +++++++++++++++++++
 9 files changed, 1194 insertions(+), 299 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/MigrationTestUtils.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/MigrationTestUtils.java
 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/MigrationTestUtils.java
new file mode 100644
index 0000000..9314496
--- /dev/null
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/MigrationTestUtils.java
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing.utils;
+
+import org.apache.flink.api.common.accumulators.IntCounter;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertThat;
+
+/**
+ * A utility class containing common functions/classes used by multiple 
migration tests.
+ */
+public class MigrationTestUtils {
+
+       /**
+        * A non-parallel source with list state used for testing.
+        */
+       public static class CheckpointingNonParallelSourceWithListState
+                       implements SourceFunction<Tuple2<Long, Long>>, 
CheckpointedFunction {
+
+               static final ListStateDescriptor<String> STATE_DESCRIPTOR =
+                               new ListStateDescriptor<>("source-state", 
StringSerializer.INSTANCE);
+
+               static final String CHECKPOINTED_STRING = "Here be dragons!";
+               static final String CHECKPOINTED_STRING_1 = "Here be more 
dragons!";
+               static final String CHECKPOINTED_STRING_2 = "Here be yet more 
dragons!";
+               static final String CHECKPOINTED_STRING_3 = "Here be the 
mostest dragons!";
+
+               private static final long serialVersionUID = 1L;
+
+               private volatile boolean isRunning = true;
+
+               private final int numElements;
+
+               private transient ListState<String> unionListState;
+
+               CheckpointingNonParallelSourceWithListState(int numElements) {
+                       this.numElements = numElements;
+               }
+
+               @Override
+               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
+                       unionListState.clear();
+                       unionListState.add(CHECKPOINTED_STRING);
+                       unionListState.add(CHECKPOINTED_STRING_1);
+                       unionListState.add(CHECKPOINTED_STRING_2);
+                       unionListState.add(CHECKPOINTED_STRING_3);
+               }
+
+               @Override
+               public void initializeState(FunctionInitializationContext 
context) throws Exception {
+                       unionListState = 
context.getOperatorStateStore().getListState(
+                                       STATE_DESCRIPTOR);
+               }
+
+               @Override
+               public void run(SourceContext<Tuple2<Long, Long>> ctx) throws 
Exception {
+
+                       ctx.emitWatermark(new Watermark(0));
+
+                       synchronized (ctx.getCheckpointLock()) {
+                               for (long i = 0; i < numElements; i++) {
+                                       ctx.collect(new Tuple2<>(i, i));
+                               }
+                       }
+
+                       // don't emit a final watermark so that we don't 
trigger the registered event-time
+                       // timers
+                       while (isRunning) {
+                               Thread.sleep(20);
+                       }
+               }
+
+               @Override
+               public void cancel() {
+                       isRunning = false;
+               }
+       }
+
+       /**
+        * A non-parallel source with union state used to verify the restored 
state of
+        * {@link CheckpointingNonParallelSourceWithListState}.
+        */
+       public static class CheckingNonParallelSourceWithListState
+                       extends RichSourceFunction<Tuple2<Long, Long>> 
implements CheckpointedFunction {
+
+               private static final long serialVersionUID = 1L;
+
+               static final String SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR = 
CheckingNonParallelSourceWithListState.class + "_RESTORE_CHECK";
+
+               private volatile boolean isRunning = true;
+
+               private final int numElements;
+
+               CheckingNonParallelSourceWithListState(int numElements) {
+                       this.numElements = numElements;
+               }
+
+               @Override
+               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
+
+               }
+
+               @Override
+               public void initializeState(FunctionInitializationContext 
context) throws Exception {
+                       ListState<String> unionListState = 
context.getOperatorStateStore().getListState(
+                                       
CheckpointingNonParallelSourceWithListState.STATE_DESCRIPTOR);
+
+                       if (context.isRestored()) {
+                               assertThat(unionListState.get(),
+                                               containsInAnyOrder(
+                                                               
CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING,
+                                                               
CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_1,
+                                                               
CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_2,
+                                                               
CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_3));
+
+                               
getRuntimeContext().addAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, new 
IntCounter());
+                               
getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1);
+                       } else {
+                               throw new RuntimeException(
+                                               "This source should always be 
restored because it's only used when restoring from a savepoint.");
+                       }
+               }
+
+               @Override
+               public void run(SourceContext<Tuple2<Long, Long>> ctx) throws 
Exception {
+
+                       // immediately trigger any set timers
+                       ctx.emitWatermark(new Watermark(1000));
+
+                       synchronized (ctx.getCheckpointLock()) {
+                               for (long i = 0; i < numElements; i++) {
+                                       ctx.collect(new Tuple2<>(i, i));
+                               }
+                       }
+
+                       while (isRunning) {
+                               Thread.sleep(20);
+                       }
+               }
+
+               @Override
+               public void cancel() {
+                       isRunning = false;
+               }
+       }
+
+       /**
+        * A parallel source with union state used for testing.
+        */
+       public static class CheckpointingParallelSourceWithUnionListState
+                       extends RichSourceFunction<Tuple2<Long, Long>> 
implements CheckpointedFunction {
+
+               static final ListStateDescriptor<String> STATE_DESCRIPTOR =
+                               new ListStateDescriptor<>("source-state", 
StringSerializer.INSTANCE);
+
+               static final String[] CHECKPOINTED_STRINGS = {
+                               "Here be dragons!",
+                               "Here be more dragons!",
+                               "Here be yet more dragons!",
+                               "Here be the mostest dragons!" };
+
+               private static final long serialVersionUID = 1L;
+
+               private volatile boolean isRunning = true;
+
+               private final int numElements;
+
+               private transient ListState<String> unionListState;
+
+               CheckpointingParallelSourceWithUnionListState(int numElements) {
+                       this.numElements = numElements;
+               }
+
+               @Override
+               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
+                       unionListState.clear();
+
+                       for (String s : CHECKPOINTED_STRINGS) {
+                               if (s.hashCode() % 
getRuntimeContext().getNumberOfParallelSubtasks() == 
getRuntimeContext().getIndexOfThisSubtask()) {
+                                       unionListState.add(s);
+                               }
+                       }
+               }
+
+               @Override
+               public void initializeState(FunctionInitializationContext 
context) throws Exception {
+                       unionListState = 
context.getOperatorStateStore().getUnionListState(
+                                       STATE_DESCRIPTOR);
+               }
+
+               @Override
+               public void run(SourceContext<Tuple2<Long, Long>> ctx) throws 
Exception {
+
+                       ctx.emitWatermark(new Watermark(0));
+
+                       synchronized (ctx.getCheckpointLock()) {
+                               for (long i = 0; i < numElements; i++) {
+                                       if (i % 
getRuntimeContext().getNumberOfParallelSubtasks() == 
getRuntimeContext().getIndexOfThisSubtask()) {
+                                               ctx.collect(new Tuple2<>(i, i));
+                                       }
+                               }
+                       }
+
+                       // don't emit a final watermark so that we don't 
trigger the registered event-time
+                       // timers
+                       while (isRunning) {
+                               Thread.sleep(20);
+                       }
+               }
+
+               @Override
+               public void cancel() {
+                       isRunning = false;
+               }
+       }
+
+       /**
+        * A parallel source with union state used to verify the restored state 
of
+        * {@link CheckpointingParallelSourceWithUnionListState}.
+        */
+       public static class CheckingParallelSourceWithUnionListState
+                       extends RichParallelSourceFunction<Tuple2<Long, Long>> 
implements CheckpointedFunction {
+
+               private static final long serialVersionUID = 1L;
+
+               static final String SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR = 
CheckingParallelSourceWithUnionListState.class + "_RESTORE_CHECK";
+
+               private volatile boolean isRunning = true;
+
+               private final int numElements;
+
+               CheckingParallelSourceWithUnionListState(int numElements) {
+                       this.numElements = numElements;
+               }
+
+               @Override
+               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
+
+               }
+
+               @Override
+               public void initializeState(FunctionInitializationContext 
context) throws Exception {
+                       ListState<String> unionListState = 
context.getOperatorStateStore().getUnionListState(
+                                       
CheckpointingNonParallelSourceWithListState.STATE_DESCRIPTOR);
+
+                       if (context.isRestored()) {
+                               assertThat(unionListState.get(),
+                                               
containsInAnyOrder(CheckpointingParallelSourceWithUnionListState.CHECKPOINTED_STRINGS));
+
+                               
getRuntimeContext().addAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, new 
IntCounter());
+                               
getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1);
+                       } else {
+                               throw new RuntimeException(
+                                               "This source should always be 
restored because it's only used when restoring from a savepoint.");
+                       }
+               }
+
+               @Override
+               public void run(SourceContext<Tuple2<Long, Long>> ctx) throws 
Exception {
+
+                       // immediately trigger any set timers
+                       ctx.emitWatermark(new Watermark(1000));
+
+                       synchronized (ctx.getCheckpointLock()) {
+                               for (long i = 0; i < numElements; i++) {
+                                       if (i % 
getRuntimeContext().getNumberOfParallelSubtasks() == 
getRuntimeContext().getIndexOfThisSubtask()) {
+                                               ctx.collect(new Tuple2<>(i, i));
+                                       }
+                               }
+                       }
+
+                       while (isRunning) {
+                               Thread.sleep(20);
+                       }
+               }
+
+               @Override
+               public void cancel() {
+                       isRunning = false;
+               }
+       }
+
+       /**
+        * A sink which counts the elements it sees in an accumulator.
+        */
+       public static class AccumulatorCountingSink<T> extends 
RichSinkFunction<T> {
+               private static final long serialVersionUID = 1L;
+
+               static final String NUM_ELEMENTS_ACCUMULATOR = 
AccumulatorCountingSink.class + "_NUM_ELEMENTS";
+
+               int count = 0;
+
+               @Override
+               public void open(Configuration parameters) throws Exception {
+                       super.open(parameters);
+
+                       
getRuntimeContext().addAccumulator(NUM_ELEMENTS_ACCUMULATOR, new IntCounter());
+               }
+
+               @Override
+               public void invoke(T value, Context context) throws Exception {
+                       count++;
+                       
getRuntimeContext().getAccumulator(NUM_ELEMENTS_ACCUMULATOR).add(1);
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java
index cfa155b..84cb88a 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java
@@ -144,7 +144,13 @@ public abstract class SavepointMigrationTestBase extends 
TestBaseUtils {
 
                        boolean allDone = true;
                        for (Tuple2<String, Integer> acc : 
expectedAccumulators) {
-                               Integer numFinished = (Integer) 
accumulators.get(acc.f0).get();
+                               OptionalFailure<Object> accumOpt = 
accumulators.get(acc.f0);
+                               if (accumOpt == null) {
+                                       allDone = false;
+                                       break;
+                               }
+
+                               Integer numFinished = (Integer) accumOpt.get();
                                if (numFinished == null) {
                                        allDone = false;
                                        break;

http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java
index d2de881..c74f304 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointMigrationITCase.java
@@ -21,26 +21,17 @@ package org.apache.flink.test.checkpointing.utils;
 import org.apache.flink.api.common.accumulators.IntCounter;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
-import org.apache.flink.api.common.state.ListState;
-import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeHint;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
-import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
-import org.apache.flink.runtime.state.FunctionInitializationContext;
-import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.StateBackendLoader;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.TimeCharacteristic;
-import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
-import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
-import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.InternalTimer;
@@ -59,9 +50,7 @@ import org.junit.runners.Parameterized;
 import java.util.Arrays;
 import java.util.Collection;
 
-import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThat;
 
 /**
  * Migration ITCases for a stateful job. The tests are parameterized to cover
@@ -130,13 +119,13 @@ public class StatefulJobSavepointMigrationITCase extends 
SavepointMigrationTestB
                OneInputStreamOperator<Tuple2<Long, Long>, Tuple2<Long, Long>> 
timelyOperator;
 
                if (executionMode == ExecutionMode.PERFORM_SAVEPOINT) {
-                       nonParallelSource = new 
CheckpointingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS);
-                       parallelSource = new 
CheckpointingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS);
+                       nonParallelSource = new 
MigrationTestUtils.CheckpointingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS);
+                       parallelSource = new 
MigrationTestUtils.CheckpointingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS);
                        flatMap = new CheckpointingKeyedStateFlatMap();
                        timelyOperator = new 
CheckpointingTimelyStatefulOperator();
                } else if (executionMode == ExecutionMode.VERIFY_SAVEPOINT) {
-                       nonParallelSource = new 
CheckingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS);
-                       parallelSource = new 
CheckingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS);
+                       nonParallelSource = new 
MigrationTestUtils.CheckingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS);
+                       parallelSource = new 
MigrationTestUtils.CheckingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS);
                        flatMap = new CheckingKeyedStateFlatMap();
                        timelyOperator = new CheckingTimelyStatefulOperator();
                } else {
@@ -152,7 +141,7 @@ public class StatefulJobSavepointMigrationITCase extends 
SavepointMigrationTestB
                                "timely_stateful_operator",
                                new TypeHint<Tuple2<Long, Long>>() 
{}.getTypeInfo(),
                                
timelyOperator).uid("CheckpointingTimelyStatefulOperator1")
-                       .addSink(new AccumulatorCountingSink<>());
+                       .addSink(new 
MigrationTestUtils.AccumulatorCountingSink<>());
 
                env
                        .addSource(parallelSource).uid("CheckpointingSource2")
@@ -163,24 +152,24 @@ public class StatefulJobSavepointMigrationITCase extends 
SavepointMigrationTestB
                                "timely_stateful_operator",
                                new TypeHint<Tuple2<Long, Long>>() 
{}.getTypeInfo(),
                                
timelyOperator).uid("CheckpointingTimelyStatefulOperator2")
-                       .addSink(new AccumulatorCountingSink<>());
+                       .addSink(new 
MigrationTestUtils.AccumulatorCountingSink<>());
 
                if (executionMode == ExecutionMode.PERFORM_SAVEPOINT) {
                        executeAndSavepoint(
                                env,
                                "src/test/resources/" + 
getSavepointPath(testMigrateVersion, testStateBackend),
-                               new 
Tuple2<>(AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS 
* 2));
+                               new 
Tuple2<>(MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, 
NUM_SOURCE_ELEMENTS * 2));
                } else {
                        restoreAndExecute(
                                env,
                                
getResourceFilename(getSavepointPath(testMigrateVersion, testStateBackend)),
-                               new 
Tuple2<>(CheckingNonParallelSourceWithListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR,
 1),
-                               new 
Tuple2<>(CheckingParallelSourceWithUnionListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR,
 parallelism),
+                               new 
Tuple2<>(MigrationTestUtils.CheckingNonParallelSourceWithListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR,
 1),
+                               new 
Tuple2<>(MigrationTestUtils.CheckingParallelSourceWithUnionListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR,
 parallelism),
                                new 
Tuple2<>(CheckingKeyedStateFlatMap.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, 
NUM_SOURCE_ELEMENTS * 2),
                                new 
Tuple2<>(CheckingTimelyStatefulOperator.SUCCESSFUL_PROCESS_CHECK_ACCUMULATOR, 
NUM_SOURCE_ELEMENTS * 2),
                                new 
Tuple2<>(CheckingTimelyStatefulOperator.SUCCESSFUL_EVENT_TIME_CHECK_ACCUMULATOR,
 NUM_SOURCE_ELEMENTS * 2),
                                new 
Tuple2<>(CheckingTimelyStatefulOperator.SUCCESSFUL_PROCESSING_TIME_CHECK_ACCUMULATOR,
 NUM_SOURCE_ELEMENTS * 2),
-                               new 
Tuple2<>(AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS 
* 2));
+                               new 
Tuple2<>(MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, 
NUM_SOURCE_ELEMENTS * 2));
                }
        }
 
@@ -195,261 +184,6 @@ public class StatefulJobSavepointMigrationITCase extends 
SavepointMigrationTestB
                }
        }
 
	/**
	 * A non-parallel source that, on every checkpoint, overwrites its operator list state
	 * with four well-known strings. Used when generating the binary savepoints that the
	 * migration tests later restore and verify.
	 */
	private static class CheckpointingNonParallelSourceWithListState
		implements SourceFunction<Tuple2<Long, Long>>, CheckpointedFunction {

		static final ListStateDescriptor<String> STATE_DESCRIPTOR =
			new ListStateDescriptor<>("source-state", StringSerializer.INSTANCE);

		static final String CHECKPOINTED_STRING = "Here be dragons!";
		static final String CHECKPOINTED_STRING_1 = "Here be more dragons!";
		static final String CHECKPOINTED_STRING_2 = "Here be yet more dragons!";
		static final String CHECKPOINTED_STRING_3 = "Here be the mostest dragons!";

		private static final long serialVersionUID = 1L;

		// flipped by cancel() to let run() terminate
		private volatile boolean isRunning = true;

		// number of (i, i) tuples emitted before the source idles
		private final int numElements;

		private transient ListState<String> unionListState;

		CheckpointingNonParallelSourceWithListState(int numElements) {
			this.numElements = numElements;
		}

		@Override
		public void snapshotState(FunctionSnapshotContext context) throws Exception {
			// replace the state contents with the four known strings on every snapshot
			unionListState.clear();
			unionListState.add(CHECKPOINTED_STRING);
			unionListState.add(CHECKPOINTED_STRING_1);
			unionListState.add(CHECKPOINTED_STRING_2);
			unionListState.add(CHECKPOINTED_STRING_3);
		}

		@Override
		public void initializeState(FunctionInitializationContext context) throws Exception {
			unionListState = context.getOperatorStateStore().getListState(
				STATE_DESCRIPTOR);
		}

		@Override
		public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception {

			ctx.emitWatermark(new Watermark(0));

			// emit all elements under the checkpoint lock so no checkpoint splits the batch
			synchronized (ctx.getCheckpointLock()) {
				for (long i = 0; i < numElements; i++) {
					ctx.collect(new Tuple2<>(i, i));
				}
			}

			// don't emit a final watermark so that we don't trigger the registered event-time
			// timers
			while (isRunning) {
				Thread.sleep(20);
			}
		}

		@Override
		public void cancel() {
			isRunning = false;
		}
	}
-
	/**
	 * A non-parallel source used when verifying a restored savepoint: in
	 * {@code initializeState()} it asserts that the restored list state contains exactly
	 * the four strings written by {@code CheckpointingNonParallelSourceWithListState},
	 * and reports success to the test harness through an accumulator.
	 */
	private static class CheckingNonParallelSourceWithListState
		extends RichSourceFunction<Tuple2<Long, Long>> implements CheckpointedFunction {

		private static final long serialVersionUID = 1L;

		static final String SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR = CheckingNonParallelSourceWithListState.class + "_RESTORE_CHECK";

		// flipped by cancel() to let run() terminate
		private volatile boolean isRunning = true;

		// number of (i, i) tuples emitted before the source idles
		private final int numElements;

		CheckingNonParallelSourceWithListState(int numElements) {
			this.numElements = numElements;
		}

		@Override
		public void snapshotState(FunctionSnapshotContext context) throws Exception {
			// nothing to snapshot; this source only verifies restored state

		}

		@Override
		public void initializeState(FunctionInitializationContext context) throws Exception {
			ListState<String> unionListState = context.getOperatorStateStore().getListState(
				CheckpointingNonParallelSourceWithListState.STATE_DESCRIPTOR);

			if (context.isRestored()) {
				// the restored state must hold exactly the strings the checkpointing source wrote
				assertThat(unionListState.get(),
					containsInAnyOrder(
						CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING,
						CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_1,
						CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_2,
						CheckpointingNonParallelSourceWithListState.CHECKPOINTED_STRING_3));

				// signal the successful restore check to the harness via an accumulator
				getRuntimeContext().addAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, new IntCounter());
				getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1);
			} else {
				throw new RuntimeException(
					"This source should always be restored because it's only used when restoring from a savepoint.");
			}
		}

		@Override
		public void run(SourceContext<Tuple2<Long, Long>> ctx) throws Exception {

			// immediately trigger any set timers
			ctx.emitWatermark(new Watermark(1000));

			synchronized (ctx.getCheckpointLock()) {
				for (long i = 0; i < numElements; i++) {
					ctx.collect(new Tuple2<>(i, i));
				}
			}

			// keep the source alive until the test cancels it
			while (isRunning) {
				Thread.sleep(20);
			}
		}

		@Override
		public void cancel() {
			isRunning = false;
		}
	}
-
-       private static class CheckpointingParallelSourceWithUnionListState
-               extends RichSourceFunction<Tuple2<Long, Long>> implements 
CheckpointedFunction {
-
-               static final ListStateDescriptor<String> STATE_DESCRIPTOR =
-                       new ListStateDescriptor<>("source-state", 
StringSerializer.INSTANCE);
-
-               static final String[] CHECKPOINTED_STRINGS = {
-                       "Here be dragons!",
-                       "Here be more dragons!",
-                       "Here be yet more dragons!",
-                       "Here be the mostest dragons!" };
-
-               private static final long serialVersionUID = 1L;
-
-               private volatile boolean isRunning = true;
-
-               private final int numElements;
-
-               private transient ListState<String> unionListState;
-
-               CheckpointingParallelSourceWithUnionListState(int numElements) {
-                       this.numElements = numElements;
-               }
-
-               @Override
-               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
-                       unionListState.clear();
-
-                       for (String s : CHECKPOINTED_STRINGS) {
-                               if (s.hashCode() % 
getRuntimeContext().getNumberOfParallelSubtasks() == 
getRuntimeContext().getIndexOfThisSubtask()) {
-                                       unionListState.add(s);
-                               }
-                       }
-               }
-
-               @Override
-               public void initializeState(FunctionInitializationContext 
context) throws Exception {
-                       unionListState = 
context.getOperatorStateStore().getUnionListState(
-                               STATE_DESCRIPTOR);
-               }
-
-               @Override
-               public void run(SourceContext<Tuple2<Long, Long>> ctx) throws 
Exception {
-
-                       ctx.emitWatermark(new Watermark(0));
-
-                       synchronized (ctx.getCheckpointLock()) {
-                               for (long i = 0; i < numElements; i++) {
-                                       if (i % 
getRuntimeContext().getNumberOfParallelSubtasks() == 
getRuntimeContext().getIndexOfThisSubtask()) {
-                                               ctx.collect(new Tuple2<>(i, i));
-                                       }
-                               }
-                       }
-
-                       // don't emit a final watermark so that we don't 
trigger the registered event-time
-                       // timers
-                       while (isRunning) {
-                               Thread.sleep(20);
-                       }
-               }
-
-               @Override
-               public void cancel() {
-                       isRunning = false;
-               }
-       }
-
-       private static class CheckingParallelSourceWithUnionListState
-               extends RichParallelSourceFunction<Tuple2<Long, Long>> 
implements CheckpointedFunction {
-
-               private static final long serialVersionUID = 1L;
-
-               static final String SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR = 
CheckingParallelSourceWithUnionListState.class + "_RESTORE_CHECK";
-
-               private volatile boolean isRunning = true;
-
-               private final int numElements;
-
-               CheckingParallelSourceWithUnionListState(int numElements) {
-                       this.numElements = numElements;
-               }
-
-               @Override
-               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
-
-               }
-
-               @Override
-               public void initializeState(FunctionInitializationContext 
context) throws Exception {
-                       ListState<String> unionListState = 
context.getOperatorStateStore().getUnionListState(
-                               
CheckpointingNonParallelSourceWithListState.STATE_DESCRIPTOR);
-
-                       if (context.isRestored()) {
-                               assertThat(unionListState.get(),
-                                       
containsInAnyOrder(CheckpointingParallelSourceWithUnionListState.CHECKPOINTED_STRINGS));
-
-                               
getRuntimeContext().addAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, new 
IntCounter());
-                               
getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1);
-                       } else {
-                               throw new RuntimeException(
-                                       "This source should always be restored 
because it's only used when restoring from a savepoint.");
-                       }
-               }
-
-               @Override
-               public void run(SourceContext<Tuple2<Long, Long>> ctx) throws 
Exception {
-
-                       // immediately trigger any set timers
-                       ctx.emitWatermark(new Watermark(1000));
-
-                       synchronized (ctx.getCheckpointLock()) {
-                               for (long i = 0; i < numElements; i++) {
-                                       if (i % 
getRuntimeContext().getNumberOfParallelSubtasks() == 
getRuntimeContext().getIndexOfThisSubtask()) {
-                                               ctx.collect(new Tuple2<>(i, i));
-                                       }
-                               }
-                       }
-
-                       while (isRunning) {
-                               Thread.sleep(20);
-                       }
-               }
-
-               @Override
-               public void cancel() {
-                       isRunning = false;
-               }
-       }
-
        private static class CheckpointingKeyedStateFlatMap extends 
RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {
 
                private static final long serialVersionUID = 1L;
@@ -609,26 +343,4 @@ public class StatefulJobSavepointMigrationITCase extends 
SavepointMigrationTestB
                        
getRuntimeContext().getAccumulator(SUCCESSFUL_PROCESSING_TIME_CHECK_ACCUMULATOR).add(1);
                }
        }
-
	/**
	 * A sink that counts every element it receives, both in a local field and in the
	 * {@code NUM_ELEMENTS_ACCUMULATOR} accumulator that the test harness reads back to
	 * verify how many elements flowed through the pipeline.
	 */
	private static class AccumulatorCountingSink<T> extends RichSinkFunction<T> {
		private static final long serialVersionUID = 1L;

		static final String NUM_ELEMENTS_ACCUMULATOR = AccumulatorCountingSink.class + "_NUM_ELEMENTS";

		// local element count; the accumulator is what the harness actually inspects
		int count = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			getRuntimeContext().addAccumulator(NUM_ELEMENTS_ACCUMULATOR, new IntCounter());
		}

		@Override
		public void invoke(T value, Context context) throws Exception {
			count++;
			getRuntimeContext().getAccumulator(NUM_ELEMENTS_ACCUMULATOR).add(1);
		}
	}
-
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobWBroadcastStateMigrationITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobWBroadcastStateMigrationITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobWBroadcastStateMigrationITCase.java
new file mode 100644
index 0000000..147415d
--- /dev/null
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobWBroadcastStateMigrationITCase.java
@@ -0,0 +1,391 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing.utils;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
+import org.apache.flink.runtime.state.StateBackendLoader;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.BroadcastStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.util.migration.MigrationVersion;
+import org.apache.flink.util.Collector;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Migration ITCases for a stateful job with broadcast state. The tests are 
parameterized to (potentially)
+ * cover migrating for multiple previous Flink versions, as well as for 
different state backends.
+ */
+@RunWith(Parameterized.class)
+public class StatefulJobWBroadcastStateMigrationITCase extends 
SavepointMigrationTestBase {
+
	// number of elements each source emits before idling
	private static final int NUM_SOURCE_ELEMENTS = 4;

	// TODO change this to PERFORM_SAVEPOINT to regenerate binary savepoints
	private final StatefulJobSavepointMigrationITCase.ExecutionMode executionMode =
			StatefulJobSavepointMigrationITCase.ExecutionMode.VERIFY_SAVEPOINT;
+
	/**
	 * Savepoint version / state-backend combinations to run; currently Flink 1.5
	 * savepoints for both the memory and the RocksDB state backend.
	 */
	@Parameterized.Parameters(name = "Migrate Savepoint / Backend: {0}")
	public static Collection<Tuple2<MigrationVersion, String>> parameters () {
		return Arrays.asList(
				Tuple2.of(MigrationVersion.v1_5, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
				Tuple2.of(MigrationVersion.v1_5, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME));
	}
+
	// Flink version the savepoint under test was written with
	private final MigrationVersion testMigrateVersion;
	// name of the state backend the savepoint was written with (memory or rocksdb)
	private final String testStateBackend;

	public StatefulJobWBroadcastStateMigrationITCase(Tuple2<MigrationVersion, String> testMigrateVersionAndBackend) throws Exception {
		this.testMigrateVersion = testMigrateVersionAndBackend.f0;
		this.testStateBackend = testMigrateVersionAndBackend.f1;
	}
+
	/**
	 * Depending on {@code executionMode}, either runs the broadcast-state job and writes a
	 * savepoint into src/test/resources (PERFORM_SAVEPOINT), or restores the job from a
	 * previously generated binary savepoint and verifies the restored broadcast, list, and
	 * union-list state through accumulators (VERIFY_SAVEPOINT).
	 */
	@Test
	public void testSavepoint() throws Exception {

		final int parallelism = 4;

		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setRestartStrategy(RestartStrategies.noRestart());
		env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);

		// configure the state backend under test
		switch (testStateBackend) {
			case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME:
				env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend()));
				break;
			case StateBackendLoader.MEMORY_STATE_BACKEND_NAME:
				env.setStateBackend(new MemoryStateBackend());
				break;
			default:
				throw new UnsupportedOperationException();
		}

		env.enableCheckpointing(500);
		env.setParallelism(parallelism);
		env.setMaxParallelism(parallelism);

		// two keyed sources and two broadcast sources; the concrete implementations depend
		// on whether we are generating or verifying the savepoint
		SourceFunction<Tuple2<Long, Long>> nonParallelSource;
		SourceFunction<Tuple2<Long, Long>> nonParallelSourceB;
		SourceFunction<Tuple2<Long, Long>> parallelSource;
		SourceFunction<Tuple2<Long, Long>> parallelSourceB;
		KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> firstBroadcastFunction;
		KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> secondBroadcastFunction;

		// expected contents of "broadcast-state-1" after restore (written by
		// CheckpointingKeyedBroadcastFunction from the (i, i) source elements)
		final Map<Long, Long> expectedFirstState = new HashMap<>();
		expectedFirstState.put(0L, 0L);
		expectedFirstState.put(1L, 1L);
		expectedFirstState.put(2L, 2L);
		expectedFirstState.put(3L, 3L);

		// expected contents of "broadcast-state-2" after restore (same elements as Strings)
		final Map<String, String> expectedSecondState = new HashMap<>();
		expectedSecondState.put("0", "0");
		expectedSecondState.put("1", "1");
		expectedSecondState.put("2", "2");
		expectedSecondState.put("3", "3");

		// expected contents of "broadcast-state-3" after restore
		final Map<String, String> expectedThirdState = new HashMap<>();
		expectedThirdState.put("0", "0");
		expectedThirdState.put("1", "1");
		expectedThirdState.put("2", "2");
		expectedThirdState.put("3", "3");

		if (executionMode == StatefulJobSavepointMigrationITCase.ExecutionMode.PERFORM_SAVEPOINT) {
			// generating mode: use the state-writing variants
			nonParallelSource = new MigrationTestUtils.CheckpointingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS);
			nonParallelSourceB = new MigrationTestUtils.CheckpointingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS);
			parallelSource = new MigrationTestUtils.CheckpointingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS);
			parallelSourceB = new MigrationTestUtils.CheckpointingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS);
			firstBroadcastFunction = new CheckpointingKeyedBroadcastFunction();
			secondBroadcastFunction = new CheckpointingKeyedSingleBroadcastFunction();
		} else if (executionMode == StatefulJobSavepointMigrationITCase.ExecutionMode.VERIFY_SAVEPOINT) {
			// verifying mode: use the state-checking variants
			nonParallelSource = new MigrationTestUtils.CheckingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS);
			nonParallelSourceB = new MigrationTestUtils.CheckingNonParallelSourceWithListState(NUM_SOURCE_ELEMENTS);
			parallelSource = new MigrationTestUtils.CheckingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS);
			parallelSourceB = new MigrationTestUtils.CheckingParallelSourceWithUnionListState(NUM_SOURCE_ELEMENTS);
			firstBroadcastFunction = new CheckingKeyedBroadcastFunction(expectedFirstState, expectedSecondState);
			secondBroadcastFunction = new CheckingKeyedSingleBroadcastFunction(expectedThirdState);
		} else {
			throw new IllegalStateException("Unknown ExecutionMode " + executionMode);
		}

		// keyed (non-broadcast) inputs; the uids must stay stable so the savepoint maps
		// operator state back to the right operators on restore
		KeyedStream<Tuple2<Long, Long>, Long> npStream = env
				.addSource(nonParallelSource).uid("CheckpointingSource1")
				.keyBy(new KeySelector<Tuple2<Long, Long>, Long>() {

					private static final long serialVersionUID = -4514793867774977152L;

					@Override
					public Long getKey(Tuple2<Long, Long> value) throws Exception {
						return value.f0;
					}
				});

		KeyedStream<Tuple2<Long, Long>, Long> pStream = env
				.addSource(parallelSource).uid("CheckpointingSource2")
				.keyBy(new KeySelector<Tuple2<Long, Long>, Long>() {

					private static final long serialVersionUID = 4940496713319948104L;

					@Override
					public Long getKey(Tuple2<Long, Long> value) throws Exception {
						return value.f0;
					}
				});

		// descriptors for the three broadcast states exercised by the job; names and types
		// must match the ones used inside the broadcast functions
		final MapStateDescriptor<Long, Long> firstBroadcastStateDesc = new MapStateDescriptor<>(
				"broadcast-state-1", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO
		);

		final MapStateDescriptor<String, String> secondBroadcastStateDesc = new MapStateDescriptor<>(
				"broadcast-state-2", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
		);

		final MapStateDescriptor<String, String> thirdBroadcastStateDesc = new MapStateDescriptor<>(
				"broadcast-state-3", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
		);

		// first broadcast input carries two broadcast states, second carries one
		BroadcastStream<Tuple2<Long, Long>> npBroadcastStream = env
				.addSource(nonParallelSourceB).uid("BrCheckpointingSource1")
				.broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc);

		BroadcastStream<Tuple2<Long, Long>> pBroadcastStream = env
				.addSource(parallelSourceB).uid("BrCheckpointingSource2")
				.broadcast(thirdBroadcastStateDesc);

		npStream
				.connect(npBroadcastStream)
				.process(firstBroadcastFunction).uid("BrProcess1")
				.addSink(new MigrationTestUtils.AccumulatorCountingSink<>());

		pStream
				.connect(pBroadcastStream)
				.process(secondBroadcastFunction).uid("BrProcess2")
				.addSink(new MigrationTestUtils.AccumulatorCountingSink<>());

		if (executionMode == StatefulJobSavepointMigrationITCase.ExecutionMode.PERFORM_SAVEPOINT) {
			executeAndSavepoint(
					env,
					"src/test/resources/" + getBroadcastSavepointPath(testMigrateVersion, testStateBackend),
					new Tuple2<>(MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, 2 * NUM_SOURCE_ELEMENTS));
		} else {
			restoreAndExecute(
					env,
					getResourceFilename(getBroadcastSavepointPath(testMigrateVersion, testStateBackend)),
					new Tuple2<>(MigrationTestUtils.CheckingNonParallelSourceWithListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, 2), // we have 2 sources
					new Tuple2<>(MigrationTestUtils.CheckingParallelSourceWithUnionListState.SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR, 2 * parallelism), // we have 2 sources
					new Tuple2<>(MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS * 2)
			);
		}
	}
+
+       private String getBroadcastSavepointPath(MigrationVersion 
savepointVersion, String backendType) {
+               switch (backendType) {
+                       case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME:
+                               return 
"new-stateful-broadcast-udf-migration-itcase-flink" + savepointVersion + 
"-rocksdb-savepoint";
+                       case StateBackendLoader.MEMORY_STATE_BACKEND_NAME:
+                               return 
"new-stateful-broadcast-udf-migration-itcase-flink" + savepointVersion + 
"-savepoint";
+                       default:
+                               throw new UnsupportedOperationException();
+               }
+       }
+
+       /**
+        * A simple {@link KeyedBroadcastProcessFunction} that puts everything 
on the broadcast side in the state.
+        */
+       private static class CheckpointingKeyedBroadcastFunction
+                       extends KeyedBroadcastProcessFunction<Long, 
Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
+
+               private static final long serialVersionUID = 
1333992081671604521L;
+
+               private MapStateDescriptor<Long, Long> firstStateDesc;
+
+               private MapStateDescriptor<String, String> secondStateDesc;
+
+               @Override
+               public void open(Configuration parameters) throws Exception {
+                       super.open(parameters);
+
+                       firstStateDesc = new MapStateDescriptor<>(
+                                       "broadcast-state-1", 
BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO
+                       );
+
+                       secondStateDesc = new MapStateDescriptor<>(
+                                       "broadcast-state-2", 
BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
+                       );
+               }
+
+               @Override
+               public void processElement(Tuple2<Long, Long> value, 
ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
+                       out.collect(value);
+               }
+
+               @Override
+               public void processBroadcastElement(Tuple2<Long, Long> value, 
Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
+                       ctx.getBroadcastState(firstStateDesc).put(value.f0, 
value.f1);
+                       
ctx.getBroadcastState(secondStateDesc).put(Long.toString(value.f0), 
Long.toString(value.f1));
+               }
+       }
+
	/**
	 * A {@link KeyedBroadcastProcessFunction} that forwards keyed elements unchanged and
	 * records every broadcast element in the single broadcast state "broadcast-state-3",
	 * storing key and value as their String representations.
	 */
	private static class CheckpointingKeyedSingleBroadcastFunction
			extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {

		private static final long serialVersionUID = 1333992081671604521L;

		// built in open(); name/types must match the descriptor used to declare the broadcast stream
		private MapStateDescriptor<String, String> stateDesc;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			stateDesc = new MapStateDescriptor<>(
					"broadcast-state-3", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
			);
		}

		@Override
		public void processElement(Tuple2<Long, Long> value, ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
			// keyed side: pass elements through untouched
			out.collect(value);
		}

		@Override
		public void processBroadcastElement(Tuple2<Long, Long> value, Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
			ctx.getBroadcastState(stateDesc).put(Long.toString(value.f0), Long.toString(value.f1));
		}
	}
+
+       /**
+        * A simple {@link KeyedBroadcastProcessFunction} that verifies the 
contents of the broadcast state after recovery.
+        */
+       private static class CheckingKeyedBroadcastFunction
+                       extends KeyedBroadcastProcessFunction<Long, 
Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
+
+               private static final long serialVersionUID = 
1333992081671604521L;
+
+               private final Map<Long, Long> expectedFirstState;
+
+               private final Map<String, String> expectedSecondState;
+
+               private MapStateDescriptor<Long, Long> firstStateDesc;
+
+               private MapStateDescriptor<String, String> secondStateDesc;
+
+               CheckingKeyedBroadcastFunction(Map<Long, Long> firstState, 
Map<String, String> secondState) {
+                       this.expectedFirstState = firstState;
+                       this.expectedSecondState = secondState;
+               }
+
+               @Override
+               public void open(Configuration parameters) throws Exception {
+                       super.open(parameters);
+
+                       firstStateDesc = new MapStateDescriptor<>(
+                                       "broadcast-state-1", 
BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO
+                       );
+
+                       secondStateDesc = new MapStateDescriptor<>(
+                                       "broadcast-state-2", 
BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
+                       );
+               }
+
+               @Override
+               public void processElement(Tuple2<Long, Long> value, 
ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
+
+                       final Map<Long, Long> actualFirstState = new 
HashMap<>();
+                       for (Map.Entry<Long, Long> entry: 
ctx.getBroadcastState(firstStateDesc).immutableEntries()) {
+                               actualFirstState.put(entry.getKey(), 
entry.getValue());
+                       }
+                       Assert.assertEquals(expectedFirstState, 
actualFirstState);
+
+                       final Map<String, String> actualSecondState = new 
HashMap<>();
+                       for (Map.Entry<String, String> entry: 
ctx.getBroadcastState(secondStateDesc).immutableEntries()) {
+                               actualSecondState.put(entry.getKey(), 
entry.getValue());
+                       }
+                       Assert.assertEquals(expectedSecondState, 
actualSecondState);
+
+                       out.collect(value);
+               }
+
+               @Override
+               public void processBroadcastElement(Tuple2<Long, Long> value, 
Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
+                       // now we do nothing as we just want to verify the 
contents of the broadcast state.
+               }
+       }
+
+       /**
+        * A simple {@link KeyedBroadcastProcessFunction} that verifies the 
contents of the broadcast state after recovery.
+        */
+       private static class CheckingKeyedSingleBroadcastFunction
+                       extends KeyedBroadcastProcessFunction<Long, 
Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
+
+               private static final long serialVersionUID = 
1333992081671604521L;
+
+               private final Map<String, String> expectedState;
+
+               private MapStateDescriptor<String, String> stateDesc;
+
+               CheckingKeyedSingleBroadcastFunction(Map<String, String> state) 
{
+                       this.expectedState = state;
+               }
+
+               @Override
+               public void open(Configuration parameters) throws Exception {
+                       super.open(parameters);
+
+                       stateDesc = new MapStateDescriptor<>(
+                                       "broadcast-state-3", 
BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
+                       );
+               }
+
+               @Override
+               public void processElement(Tuple2<Long, Long> value, 
ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
+                       final Map<String, String> actualState = new HashMap<>();
+                       for (Map.Entry<String, String> entry: 
ctx.getBroadcastState(stateDesc).immutableEntries()) {
+                               actualState.put(entry.getKey(), 
entry.getValue());
+                       }
+                       Assert.assertEquals(expectedState, actualState);
+
+                       out.collect(value);
+               }
+
+               @Override
+               public void processBroadcastElement(Tuple2<Long, Long> value, 
Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
+                       // now we do nothing as we just want to verify the 
contents of the broadcast state.
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata
 
b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata
new file mode 100644
index 0000000..aed6183
Binary files /dev/null and 
b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata
 differ

http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-savepoint/_metadata
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-savepoint/_metadata
 
b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-savepoint/_metadata
new file mode 100644
index 0000000..13e8d32
Binary files /dev/null and 
b/flink-tests/src/test/resources/new-stateful-broadcast-udf-migration-itcase-flink1.5-savepoint/_metadata
 differ

http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-jobmanager-savepoint/_metadata
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-jobmanager-savepoint/_metadata
 
b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-jobmanager-savepoint/_metadata
new file mode 100644
index 0000000..f845ee3
Binary files /dev/null and 
b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-jobmanager-savepoint/_metadata
 differ

http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata
 
b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata
new file mode 100644
index 0000000..6b05ef9
Binary files /dev/null and 
b/flink-tests/src/test/resources/stateful-scala-with-broadcast2.11-udf-migration-itcase-flink1.5-rocksdb-savepoint/_metadata
 differ

http://git-wip-us.apache.org/repos/asf/flink/blob/0c06852b/flink-tests/src/test/scala/org/apache/flink/api/scala/migration/StatefulJobWBroadcastStateMigrationITCase.scala
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/scala/org/apache/flink/api/scala/migration/StatefulJobWBroadcastStateMigrationITCase.scala
 
b/flink-tests/src/test/scala/org/apache/flink/api/scala/migration/StatefulJobWBroadcastStateMigrationITCase.scala
new file mode 100644
index 0000000..0ed705b
--- /dev/null
+++ 
b/flink-tests/src/test/scala/org/apache/flink/api/scala/migration/StatefulJobWBroadcastStateMigrationITCase.scala
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.migration
+
+import java.util
+
+import org.apache.flink.api.common.accumulators.IntCounter
+import org.apache.flink.api.common.functions.RichFlatMapFunction
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.functions.KeySelector
+import org.apache.flink.api.java.tuple.Tuple2
+import org.apache.flink.api.scala.createTypeInformation
+import org.apache.flink.api.scala.migration.CustomEnum.CustomEnum
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.contrib.streaming.state.RocksDBStateBackend
+import org.apache.flink.runtime.state.{FunctionInitializationContext, 
FunctionSnapshotContext, StateBackendLoader}
+import org.apache.flink.runtime.state.memory.MemoryStateBackend
+import org.apache.flink.streaming.api.TimeCharacteristic
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
+import 
org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
+import org.apache.flink.streaming.api.functions.source.SourceFunction
+import org.apache.flink.streaming.api.watermark.Watermark
+import org.apache.flink.streaming.util.migration.MigrationVersion
+import org.apache.flink.test.checkpointing.utils.SavepointMigrationTestBase
+import org.apache.flink.util.Collector
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.junit.{Assert, Ignore, Test}
+
+import scala.util.{Failure, Properties, Try}
+
object StatefulJobWBroadcastStateMigrationITCase {

  /** The (savepoint version, state backend) combinations the restore test runs against. */
  @Parameterized.Parameters(name = "Migrate Savepoint / Backend: {0}")
  def parameters: util.Collection[(MigrationVersion, String)] = {
    util.Arrays.asList(
      (MigrationVersion.v1_5, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
      (MigrationVersion.v1_5, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME))
  }

  // TODO to generate savepoints for a specific Flink version / backend type,
  // TODO change these values accordingly, e.g. to generate for 1.3 with RocksDB,
  // TODO set as (MigrationVersion.v1_3, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME)
  val GENERATE_SAVEPOINT_VER: MigrationVersion = MigrationVersion.v1_5
  val GENERATE_SAVEPOINT_BACKEND_TYPE: String = StateBackendLoader.MEMORY_STATE_BACKEND_NAME

  /** Binary Scala version (e.g. "2.11"), used in the savepoint resource path. */
  val SCALA_VERSION: String = {
    // Properties.versionString looks like "version 2.11.12"; keep only "2.11".
    val fullVersion = Properties.versionString.split(" ")(1)
    fullVersion.take(fullVersion.lastIndexOf('.'))
  }

  /** Number of elements emitted by the source; checked via the sink's accumulator. */
  val NUM_ELEMENTS = 4
}
+
/**
  * ITCase for migrating savepoints that contain Scala state types and broadcast state
  * across different Flink versions and state backends.
  */
@RunWith(classOf[Parameterized])
class StatefulJobWBroadcastStateMigrationITCase(
                                        migrationVersionAndBackend: (MigrationVersion, String))
  extends SavepointMigrationTestBase with Serializable {

  /**
    * Generates the savepoint resources restored by [[testRestoreSavepointWithBroadcast()]].
    * Kept ignored; run it manually (with GENERATE_SAVEPOINT_VER / -_BACKEND_TYPE adjusted)
    * when savepoints for a new Flink version have to be created.
    */
  @Test
  @Ignore
  def testCreateSavepointWithBroadcastState(): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_BACKEND_TYPE match {
      case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME =>
        env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend()))
      case StateBackendLoader.MEMORY_STATE_BACKEND_NAME =>
        env.setStateBackend(new MemoryStateBackend())
      case _ => throw new UnsupportedOperationException
    }

    lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
      "broadcast-state-1",
      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])

    lazy val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
      "broadcast-state-2",
      BasicTypeInfo.STRING_TYPE_INFO,
      BasicTypeInfo.STRING_TYPE_INFO)

    // NOTE: no second setStateBackend() call here -- it would silently override the
    // backend selected by the match above and savepoints would always be generated
    // with the MemoryStateBackend regardless of GENERATE_SAVEPOINT_BACKEND_TYPE.
    env.enableCheckpointing(500)
    env.setParallelism(4)
    env.setMaxParallelism(4)

    val stream = env
      .addSource(
        new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedSource")
      .keyBy(
        new KeySelector[(Long, Long), Long] {
          override def getKey(value: (Long, Long)): Long = value._1
        }
      )
      .flatMap(new StatefulFlatMapper)
      .keyBy(
        new KeySelector[(Long, Long), Long] {
          override def getKey(value: (Long, Long)): Long = value._1
        }
      )

    val broadcastStream = env
      .addSource(
        new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedBroadcastSource")
      .broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc)

    stream
      .connect(broadcastStream)
      .process(new TestBroadcastProcessFunction)
      .addSink(new AccumulatorCountingSink)

    executeAndSavepoint(
      env,
      s"src/test/resources/stateful-scala-with-broadcast" +
        s"${StatefulJobWBroadcastStateMigrationITCase.SCALA_VERSION}" +
        s"-udf-migration-itcase-flink" +
        s"${StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_VER}" +
        s"-${StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_BACKEND_TYPE}-savepoint",
      new Tuple2(
        AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR,
        StatefulJobWBroadcastStateMigrationITCase.NUM_ELEMENTS
      )
    )
  }

  /**
    * Restores the savepoint for the parameterized (version, backend) combination and
    * verifies the broadcast state contents via [[VerifyingBroadcastProcessFunction]].
    */
  @Test
  def testRestoreSavepointWithBroadcast(): Unit = {

    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    migrationVersionAndBackend._2 match {
      case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME =>
        env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend()))
      case StateBackendLoader.MEMORY_STATE_BACKEND_NAME =>
        env.setStateBackend(new MemoryStateBackend())
      case _ => throw new UnsupportedOperationException
    }

    lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
      "broadcast-state-1",
      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])

    lazy val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
      "broadcast-state-2",
      BasicTypeInfo.STRING_TYPE_INFO,
      BasicTypeInfo.STRING_TYPE_INFO)

    // FIX: removed the unconditional env.setStateBackend(new MemoryStateBackend) that
    // used to follow here. It overrode the backend chosen by the match above, so the
    // RocksDB parameterization was effectively restoring with the memory backend.
    env.enableCheckpointing(500)
    env.setParallelism(4)
    env.setMaxParallelism(4)

    val stream = env
      .addSource(
        new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedSource")
      .keyBy(
        new KeySelector[(Long, Long), Long] {
          override def getKey(value: (Long, Long)): Long = value._1
        }
      )
      .flatMap(new StatefulFlatMapper)
      .keyBy(
        new KeySelector[(Long, Long), Long] {
          override def getKey(value: (Long, Long)): Long = value._1
        }
      )

    val broadcastStream = env
      .addSource(
        new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedBroadcastSource")
      .broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc)

    // The broadcast side re-emits the same elements after restore, so these are the
    // exact contents the two broadcast states must hold.
    val expectedFirstState: Map[Long, Long] =
      Map(0L -> 0L, 1L -> 1L, 2L -> 2L, 3L -> 3L)
    val expectedSecondState: Map[String, String] =
      Map("0" -> "0", "1" -> "1", "2" -> "2", "3" -> "3")

    stream
      .connect(broadcastStream)
      .process(new VerifyingBroadcastProcessFunction(expectedFirstState, expectedSecondState))
      .addSink(new AccumulatorCountingSink)

    restoreAndExecute(
      env,
      SavepointMigrationTestBase.getResourceFilename(
        s"stateful-scala-with-broadcast${StatefulJobWBroadcastStateMigrationITCase.SCALA_VERSION}" +
          s"-udf-migration-itcase-flink${migrationVersionAndBackend._1}" +
          s"-${migrationVersionAndBackend._2}-savepoint"),
      new Tuple2(
        AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR,
        StatefulJobWBroadcastStateMigrationITCase.NUM_ELEMENTS)
    )
  }
}
+
/**
  * A [[KeyedBroadcastProcessFunction]] that forwards keyed elements unchanged and mirrors
  * every broadcast element into the two broadcast states captured in the savepoint.
  */
class TestBroadcastProcessFunction
  extends KeyedBroadcastProcessFunction
    [Long, (Long, Long), (Long, Long), (Long, Long)] {

  lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
    "broadcast-state-1",
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])

  val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
    "broadcast-state-2",
    BasicTypeInfo.STRING_TYPE_INFO,
    BasicTypeInfo.STRING_TYPE_INFO)

  @throws[Exception]
  override def processElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction[
        Long, (Long, Long), (Long, Long), (Long, Long)]#ReadOnlyContext,
      out: Collector[(Long, Long)]): Unit = {
    // Keyed elements pass through untouched; only the broadcast side writes state.
    out.collect(value)
  }

  @throws[Exception]
  override def processBroadcastElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction[
        Long, (Long, Long), (Long, Long), (Long, Long)]#Context,
      out: Collector[(Long, Long)]): Unit = {
    // Store each element once with Long keys/values and once with their String forms.
    ctx.getBroadcastState(firstBroadcastStateDesc).put(value._1, value._2)
    ctx.getBroadcastState(secondBroadcastStateDesc).put(value._1.toString, value._2.toString)
  }
}
+
@SerialVersionUID(1L)
private object CheckpointedSource {
  // Marker string for the source's operator state. FIX: was a `var` although it is
  // never reassigned; a mutable global constant invites accidental modification.
  val CHECKPOINTED_STRING = "Here be dragons!"
}
+
/**
  * A checkpointed source that emits `numElements` (i, i) pairs under the checkpoint lock
  * and then idles until cancelled. It writes a single marker element into operator state
  * on every snapshot.
  */
@SerialVersionUID(1L)
private class CheckpointedSource(val numElements: Int)
  extends SourceFunction[(Long, Long)] with CheckpointedFunction {

  private var isRunning = true
  private var state: ListState[CustomCaseClass] = _

  @throws[Exception]
  override def run(ctx: SourceFunction.SourceContext[(Long, Long)]) {
    ctx.emitWatermark(new Watermark(0))
    ctx.getCheckpointLock synchronized {
      var emitted = 0
      while (emitted < numElements) {
        ctx.collect((emitted.toLong, emitted.toLong))
        emitted += 1
      }
    }
    // Don't emit a final watermark so that we don't trigger the registered
    // event-time timers; just idle until the job is cancelled.
    while (isRunning) {
      Thread.sleep(20)
    }
  }

  override def cancel() {
    isRunning = false
  }

  override def initializeState(context: FunctionInitializationContext): Unit = {
    state = context.getOperatorStateStore.getOperatorState(
      new ListStateDescriptor[CustomCaseClass](
        "sourceState", createTypeInformation[CustomCaseClass]))
  }

  override def snapshotState(context: FunctionSnapshotContext): Unit = {
    state.clear()
    state.add(CustomCaseClass("Here be dragons!", 123))
  }
}
+
@SerialVersionUID(1L)
private object AccumulatorCountingSink {
  /**
    * Name of the accumulator counting elements arriving at the sink.
    * FIX: was a `var`; the constant is never reassigned. The value intentionally stays
    * byte-identical (it relies on `Class.toString`, i.e. "class ..._NUM_ELEMENTS"),
    * because the test harness looks the accumulator up under exactly this name.
    */
  val NUM_ELEMENTS_ACCUMULATOR: String = classOf[AccumulatorCountingSink[_]] + "_NUM_ELEMENTS"
}

/**
  * A sink that counts every invocation in a shared accumulator so tests can assert on
  * the number of elements that reached the end of the pipeline.
  */
@SerialVersionUID(1L)
private class AccumulatorCountingSink[T] extends RichSinkFunction[T] {

  // FIX: removed the private `count` field that was incremented but never read.

  @throws[Exception]
  override def open(parameters: Configuration) {
    super.open(parameters)
    getRuntimeContext.addAccumulator(
      AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, new IntCounter)
  }

  @throws[Exception]
  def invoke(value: T) {
    getRuntimeContext.getAccumulator(
      AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR).add(1)
  }
}
+
/**
  * A flat mapper that writes one instance of every supported Scala state type into keyed
  * state, so that the savepoint exercises all Scala type serializers (case classes,
  * collections, Try, Option, Either, enumerations).
  */
class StatefulFlatMapper extends RichFlatMapFunction[(Long, Long), (Long, Long)] {

  private var caseClassState: ValueState[CustomCaseClass] = _
  private var caseClassWithNestingState: ValueState[CustomCaseClassWithNesting] = _
  private var collectionState: ValueState[List[CustomCaseClass]] = _
  private var tryState: ValueState[Try[CustomCaseClass]] = _
  private var tryFailureState: ValueState[Try[CustomCaseClass]] = _
  private var optionState: ValueState[Option[CustomCaseClass]] = _
  private var optionNoneState: ValueState[Option[CustomCaseClass]] = _
  private var eitherLeftState: ValueState[Either[CustomCaseClass, String]] = _
  private var eitherRightState: ValueState[Either[CustomCaseClass, String]] = _
  private var enumOneState: ValueState[CustomEnum] = _
  private var enumThreeState: ValueState[CustomEnum] = _

  override def open(parameters: Configuration): Unit = {
    caseClassState = getRuntimeContext.getState(
      new ValueStateDescriptor[CustomCaseClass](
        "caseClassState", createTypeInformation[CustomCaseClass]))
    caseClassWithNestingState = getRuntimeContext.getState(
      new ValueStateDescriptor[CustomCaseClassWithNesting](
        "caseClassWithNestingState", createTypeInformation[CustomCaseClassWithNesting]))
    collectionState = getRuntimeContext.getState(
      new ValueStateDescriptor[List[CustomCaseClass]](
        "collectionState", createTypeInformation[List[CustomCaseClass]]))
    tryState = getRuntimeContext.getState(
      new ValueStateDescriptor[Try[CustomCaseClass]](
        "tryState", createTypeInformation[Try[CustomCaseClass]]))
    tryFailureState = getRuntimeContext.getState(
      new ValueStateDescriptor[Try[CustomCaseClass]](
        "tryFailureState", createTypeInformation[Try[CustomCaseClass]]))
    optionState = getRuntimeContext.getState(
      new ValueStateDescriptor[Option[CustomCaseClass]](
        "optionState", createTypeInformation[Option[CustomCaseClass]]))
    optionNoneState = getRuntimeContext.getState(
      new ValueStateDescriptor[Option[CustomCaseClass]](
        "optionNoneState", createTypeInformation[Option[CustomCaseClass]]))
    eitherLeftState = getRuntimeContext.getState(
      new ValueStateDescriptor[Either[CustomCaseClass, String]](
        "eitherLeftState", createTypeInformation[Either[CustomCaseClass, String]]))
    eitherRightState = getRuntimeContext.getState(
      new ValueStateDescriptor[Either[CustomCaseClass, String]](
        "eitherRightState", createTypeInformation[Either[CustomCaseClass, String]]))
    enumOneState = getRuntimeContext.getState(
      new ValueStateDescriptor[CustomEnum](
        "enumOneState", createTypeInformation[CustomEnum]))
    enumThreeState = getRuntimeContext.getState(
      new ValueStateDescriptor[CustomEnum](
        "enumThreeState", createTypeInformation[CustomEnum]))
  }

  override def flatMap(in: (Long, Long), collector: Collector[(Long, Long)]): Unit = {
    caseClassState.update(CustomCaseClass(in._1.toString, in._2 * 2))
    caseClassWithNestingState.update(
      CustomCaseClassWithNesting(in._1, CustomCaseClass(in._1.toString, in._2 * 2)))
    collectionState.update(List(CustomCaseClass(in._1.toString, in._2 * 2)))
    tryState.update(Try(CustomCaseClass(in._1.toString, in._2 * 5)))
    tryFailureState.update(Failure(new RuntimeException))
    optionState.update(Some(CustomCaseClass(in._1.toString, in._2 * 2)))
    optionNoneState.update(None)
    eitherLeftState.update(Left(CustomCaseClass(in._1.toString, in._2 * 2)))
    eitherRightState.update(Right((in._1 * 3).toString))
    enumOneState.update(CustomEnum.ONE)
    // FIX: this used to write enumOneState a second time, leaving enumThreeState
    // registered but never populated.
    enumThreeState.update(CustomEnum.THREE)

    collector.collect(in)
  }
}
+
/**
  * A [[KeyedBroadcastProcessFunction]] that verifies, on the element side, that the two
  * broadcast states restored from the savepoint contain exactly the expected entries.
  */
class VerifyingBroadcastProcessFunction(
    firstExpectedBroadcastState: Map[Long, Long],
    secondExpectedBroadcastState: Map[String, String])
  extends KeyedBroadcastProcessFunction
    [Long, (Long, Long), (Long, Long), (Long, Long)] {

  lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
    "broadcast-state-1",
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])

  val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
    "broadcast-state-2",
    BasicTypeInfo.STRING_TYPE_INFO,
    BasicTypeInfo.STRING_TYPE_INFO)

  @throws[Exception]
  override def processElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction[
        Long, (Long, Long), (Long, Long), (Long, Long)]#ReadOnlyContext,
      out: Collector[(Long, Long)]): Unit = {

    // FIX: use the explicit, non-deprecated JavaConverters instead of the (twice
    // duplicated) implicit `scala.collection.JavaConversions._` import, and compare
    // whole maps instead of calling `Option.get` per entry -- an unexpected key now
    // produces a readable assertion failure rather than a NoSuchElementException.
    import scala.collection.JavaConverters._

    val actualFirstState: Map[Long, Long] = ctx
      .getBroadcastState(firstBroadcastStateDesc)
      .immutableEntries()
      .asScala
      .map(entry => entry.getKey -> entry.getValue)
      .toMap
    Assert.assertEquals(firstExpectedBroadcastState, actualFirstState)

    val actualSecondState: Map[String, String] = ctx
      .getBroadcastState(secondBroadcastStateDesc)
      .immutableEntries()
      .asScala
      .map(entry => entry.getKey -> entry.getValue)
      .toMap
    Assert.assertEquals(secondExpectedBroadcastState, actualSecondState)

    out.collect(value)
  }

  @throws[Exception]
  override def processBroadcastElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction[
        Long, (Long, Long), (Long, Long), (Long, Long)]#Context,
      out: Collector[(Long, Long)]): Unit = {
    // Intentionally empty: this function only verifies the restored broadcast state.
  }
}

Reply via email to