http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index d318575..c791fd8 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -2731,10 +2731,15 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
        @Test
        public void testReplicateModeStateHandle() {
                Map<String, OperatorStateHandle.StateMetaInfo> metaInfoMap = 
new HashMap<>(1);
-               metaInfoMap.put("t-1", new 
OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, 
OperatorStateHandle.Mode.BROADCAST));
-               metaInfoMap.put("t-2", new 
OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, 
OperatorStateHandle.Mode.BROADCAST));
+               metaInfoMap.put("t-1", new 
OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, 
OperatorStateHandle.Mode.UNION));
+               metaInfoMap.put("t-2", new 
OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, 
OperatorStateHandle.Mode.UNION));
                metaInfoMap.put("t-3", new 
OperatorStateHandle.StateMetaInfo(new long[]{72, 83}, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
-               OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, 
new ByteStreamStateHandle("test", new byte[100]));
+               metaInfoMap.put("t-4", new 
OperatorStateHandle.StateMetaInfo(new long[]{87, 94, 95}, 
OperatorStateHandle.Mode.BROADCAST));
+               metaInfoMap.put("t-5", new 
OperatorStateHandle.StateMetaInfo(new long[]{97, 108, 112}, 
OperatorStateHandle.Mode.BROADCAST));
+               metaInfoMap.put("t-6", new 
OperatorStateHandle.StateMetaInfo(new long[]{121, 143, 147}, 
OperatorStateHandle.Mode.BROADCAST));
+
+               // this is what a single task will return
+               OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, 
new ByteStreamStateHandle("test", new byte[150]));
 
                OperatorStateRepartitioner repartitioner = 
RoundRobinOperatorStateRepartitioner.INSTANCE;
                List<Collection<OperatorStateHandle>> repartitionedStates =
@@ -2757,18 +2762,26 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
 
                                        OperatorStateHandle.StateMetaInfo 
stateMetaInfo = stateNameToMetaInfo.getValue();
                                        if 
(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(stateMetaInfo.getDistributionMode()))
 {
+                                               // SPLIT_DISTRIBUTE: the offsets are split across the new tasks, so each entry carries exactly one offset
                                                Assert.assertEquals(1, 
stateNameToMetaInfo.getValue().getOffsets().length);
-                                       } else {
+                                       } else if 
(OperatorStateHandle.Mode.UNION.equals(stateMetaInfo.getDistributionMode())) {
+                                               // UNION: all to all, each entry carries every offset of the state (t-1 / t-2 have 2 each)
                                                Assert.assertEquals(2, 
stateNameToMetaInfo.getValue().getOffsets().length);
+                                       } else {
+                                               // BROADCAST (the only remaining mode in this test): each entry carries a full copy of the state's offsets (t-4..t-6 have 3 each)
+                                               Assert.assertEquals(3, 
stateNameToMetaInfo.getValue().getOffsets().length);
                                        }
                                }
                        }
                }
 
-               Assert.assertEquals(3, checkCounts.size());
+               Assert.assertEquals(6, checkCounts.size());
                Assert.assertEquals(3, checkCounts.get("t-1").intValue());
                Assert.assertEquals(3, checkCounts.get("t-2").intValue());
                Assert.assertEquals(2, checkCounts.get("t-3").intValue());
+               Assert.assertEquals(3, checkCounts.get("t-4").intValue());
+               Assert.assertEquals(3, checkCounts.get("t-5").intValue());
+               Assert.assertEquals(3, checkCounts.get("t-6").intValue());
        }
 
        // 
------------------------------------------------------------------------
@@ -3243,7 +3256,7 @@ public class CheckpointCoordinatorTest extends TestLogger 
{
                        Path fakePath = new Path("/fake-" + i);
                        Map<String, OperatorStateHandle.StateMetaInfo> 
namedStatesToOffsets = new HashMap<>();
                        int off = 0;
-                       for (int s = 0; s < numNamedStates; ++s) {
+                       for (int s = 0; s < numNamedStates - 1; ++s) {
                                long[] offs = new long[1 + 
r.nextInt(maxPartitionsPerState)];
 
                                for (int o = 0; o < offs.length; ++o) {
@@ -3252,19 +3265,29 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                                }
 
                                OperatorStateHandle.Mode mode = r.nextInt(10) 
== 0 ?
-                                               
OperatorStateHandle.Mode.BROADCAST : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE;
+                                               OperatorStateHandle.Mode.UNION 
: OperatorStateHandle.Mode.SPLIT_DISTRIBUTE;
                                namedStatesToOffsets.put(
                                                "State-" + s,
                                                new 
OperatorStateHandle.StateMetaInfo(offs, mode));
 
                        }
 
+                       if (numNamedStates % 2 == 0) {
+                               // the loop above only fills slots 0..numNamedStates-2; for an even count the last slot becomes a BROADCAST state, for an odd count it stays unregistered
+                               long[] offs = {off + 1, off + 2, off + 3, off + 
4};
+
+                               namedStatesToOffsets.put(
+                                               "State-" + (numNamedStates - 1),
+                                               new 
OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.BROADCAST));
+                       }
+
                        previousParallelOpInstanceStates.add(
                                        new 
OperatorStateHandle(namedStatesToOffsets, new FileStateHandle(fakePath, -1)));
                }
 
                Map<StreamStateHandle, Map<String, List<Long>>> expected = new 
HashMap<>();
 
+               int taskIndex = 0;
                int expectedTotalPartitions = 0;
                for (OperatorStateHandle psh : 
previousParallelOpInstanceStates) {
                        Map<String, OperatorStateHandle.StateMetaInfo> offsMap 
= psh.getStateNameToPartitionOffsets();
@@ -3272,20 +3295,39 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                        for (Map.Entry<String, 
OperatorStateHandle.StateMetaInfo> e : offsMap.entrySet()) {
 
                                long[] offs = e.getValue().getOffsets();
-                               int replication = 
e.getValue().getDistributionMode().equals(OperatorStateHandle.Mode.BROADCAST) ?
-                                               newParallelism : 1;
+                               int replication; // how many times each offset of this state is expected across all new tasks
+                               switch (e.getValue().getDistributionMode()) {
+                                       case UNION: // every new task receives every offset
+                                               replication = newParallelism;
+                                               break;
+                                       case BROADCAST: // copies are spread evenly over the new tasks; the first (newParallelism % oldParallelism) old tasks add one extra copy, and the result can be 0 when scaling down
+                                               int extra = taskIndex < 
(newParallelism % oldParallelism) ? 1 : 0;
+                                               replication = newParallelism / 
oldParallelism + extra;
+                                               break;
+                                       case SPLIT_DISTRIBUTE: // each offset goes to exactly one new task
+                                               replication = 1;
+                                               break;
+                                       default:
+                                               throw new 
RuntimeException("Unknown distribution mode " + 
e.getValue().getDistributionMode());
+                               }
 
-                               expectedTotalPartitions += replication * 
offs.length;
-                               List<Long> offsList = new 
ArrayList<>(offs.length);
+                               if (replication > 0) {
+                                       expectedTotalPartitions += replication 
* offs.length;
+                                       List<Long> offsList = new 
ArrayList<>(offs.length);
 
-                               for (long off : offs) {
-                                       for (int p = 0; p < replication; ++p) {
-                                               offsList.add(off);
+                                       for (long off : offs) {
+                                               for (int p = 0; p < 
replication; ++p) {
+                                                       offsList.add(off);
+                                               }
                                        }
+                                       offsMapWithList.put(e.getKey(), 
offsList);
                                }
-                               offsMapWithList.put(e.getKey(), offsList);
                        }
-                       expected.put(psh.getDelegateStateHandle(), 
offsMapWithList);
+
+                       if (!offsMapWithList.isEmpty()) {
+                               expected.put(psh.getDelegateStateHandle(), 
offsMapWithList);
+                       }
+                       taskIndex++;
                }
 
                OperatorStateRepartitioner repartitioner = 
RoundRobinOperatorStateRepartitioner.INSTANCE;

http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
index acedb50..d1d67ff 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
@@ -97,7 +97,7 @@ public class CheckpointTestUtils {
                                Map<String, StateMetaInfo> offsetsMap = new 
HashMap<>();
                                offsetsMap.put("A", new 
OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
                                offsetsMap.put("B", new 
OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
-                               offsetsMap.put("C", new 
OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, 
OperatorStateHandle.Mode.BROADCAST));
+                               offsetsMap.put("C", new 
OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, 
OperatorStateHandle.Mode.UNION));
 
                                if (hasOperatorStateBackend) {
                                        operatorStateHandleBackend = new 
OperatorStateHandle(offsetsMap, operatorStateBackend);
@@ -179,7 +179,7 @@ public class CheckpointTestUtils {
                                        Map<String, StateMetaInfo> offsetsMap = 
new HashMap<>();
                                        offsetsMap.put("A", new 
OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
                                        offsetsMap.put("B", new 
OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
-                                       offsetsMap.put("C", new 
OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, 
OperatorStateHandle.Mode.BROADCAST));
+                                       offsetsMap.put("C", new 
OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, 
OperatorStateHandle.Mode.UNION));
 
                                        if (chainIdx != 
noOperatorStateBackendAtIndex) {
                                                OperatorStateHandle 
operatorStateHandleBackend =

http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index ef390db..1881dad 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -18,8 +18,11 @@
 package org.apache.flink.runtime.state;
 
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.BroadcastState;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeutils.CompatibilityResult;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot;
@@ -49,7 +52,9 @@ import java.io.File;
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Iterator;
+import java.util.Map;
 import java.util.concurrent.CancellationException;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
@@ -85,6 +90,7 @@ public class OperatorStateBackendTest {
 
                assertNotNull(operatorStateBackend);
                
assertTrue(operatorStateBackend.getRegisteredStateNames().isEmpty());
+               
assertTrue(operatorStateBackend.getRegisteredBroadcastStateNames().isEmpty());
        }
 
        @Test
@@ -233,6 +239,20 @@ public class OperatorStateBackendTest {
 
                listState.add(42);
 
+               AtomicInteger keyCopyCounter = new AtomicInteger(0);
+               AtomicInteger valueCopyCounter = new AtomicInteger(0);
+
+               TypeSerializer<Integer> keySerializer = new 
VerifyingIntSerializer(env.getUserClassLoader(), keyCopyCounter);
+               TypeSerializer<Integer> valueSerializer = new 
VerifyingIntSerializer(env.getUserClassLoader(), valueCopyCounter);
+
+               MapStateDescriptor<Integer, Integer> broadcastStateDesc = new 
MapStateDescriptor<>(
+                               "test-broadcast", keySerializer, 
valueSerializer);
+
+               BroadcastState<Integer, Integer> broadcastState = 
operatorStateBackend.getBroadcastState(broadcastStateDesc);
+               broadcastState.put(1, 2);
+               broadcastState.put(3, 4);
+               broadcastState.put(5, 6);
+
                CheckpointStreamFactory streamFactory = new 
MemCheckpointStreamFactory(4096);
                RunnableFuture<OperatorStateHandle> runnableFuture =
                        operatorStateBackend.snapshot(1, 1, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation());
@@ -240,6 +260,8 @@ public class OperatorStateBackendTest {
 
                // make sure that the copy method has been called
                assertTrue(copyCounter.get() > 0);
+               assertTrue(keyCopyCounter.get() > 0);
+               assertTrue(valueCopyCounter.get() > 0);
        }
 
        /**
@@ -361,17 +383,102 @@ public class OperatorStateBackendTest {
        }
 
        @Test
+       public void testSnapshotBroadcastStateWithEmptyOperatorState() throws 
Exception { // snapshot/restore round trips of a broadcast state while no list state is registered: full, one entry removed, then cleared
+               final AbstractStateBackend abstractStateBackend = new 
MemoryStateBackend(4096);
+
+               final OperatorStateBackend operatorStateBackend =
+                               
abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), 
"testOperator");
+
+               final MapStateDescriptor<Integer, Integer> broadcastStateDesc = 
new MapStateDescriptor<>(
+                               "test-broadcast", BasicTypeInfo.INT_TYPE_INFO, 
BasicTypeInfo.INT_TYPE_INFO);
+
+               final Map<Integer, Integer> expected = new HashMap<>(3);
+               expected.put(1, 2);
+               expected.put(3, 4);
+               expected.put(5, 6);
+
+               final BroadcastState<Integer, Integer> broadcastState = 
operatorStateBackend.getBroadcastState(broadcastStateDesc);
+               broadcastState.putAll(expected);
+
+               final CheckpointStreamFactory streamFactory = new 
MemCheckpointStreamFactory(4096);
+               OperatorStateHandle stateHandle = null; // NOTE(review): reassigned after each snapshot; only the last handle is discarded in the finally block
+
+               try {
+                       RunnableFuture<OperatorStateHandle> snapshot =
+                                       operatorStateBackend.snapshot(0L, 0L, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
+
+                       stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot);
+                       assertNotNull(stateHandle);
+
+                       final Map<Integer, Integer> retrieved = new HashMap<>();
+
+                       
operatorStateBackend.restore(Collections.singleton(stateHandle)); // restores into the same backend instance that took the snapshot
+                       BroadcastState<Integer, Integer> retrievedState = 
operatorStateBackend.getBroadcastState(broadcastStateDesc);
+                       for (Map.Entry<Integer, Integer> e: 
retrievedState.entries()) {
+                               retrieved.put(e.getKey(), e.getValue());
+                       }
+                       assertEquals(expected, retrieved);
+
+                       // remove one entry from both the expected map and the live broadcast state, then snapshot and restore again
+                       broadcastState.remove(1);
+                       expected.remove(1);
+
+                       snapshot = operatorStateBackend.snapshot(1L, 1L, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
+                       stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot);
+
+                       retrieved.clear();
+                       
operatorStateBackend.restore(Collections.singleton(stateHandle));
+                       retrievedState = 
operatorStateBackend.getBroadcastState(broadcastStateDesc);
+                       for (Map.Entry<Integer, Integer> e: 
retrievedState.immutableEntries()) {
+                               retrieved.put(e.getKey(), e.getValue());
+                       }
+                       assertEquals(expected, retrieved);
+
+                       // clear both the expected map and the live broadcast state; the restored state must then come back empty
+                       broadcastState.clear();
+                       expected.clear();
+
+                       snapshot = operatorStateBackend.snapshot(2L, 2L, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
+                       stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot);
+
+                       retrieved.clear();
+                       
operatorStateBackend.restore(Collections.singleton(stateHandle));
+                       retrievedState = 
operatorStateBackend.getBroadcastState(broadcastStateDesc);
+                       for (Map.Entry<Integer, Integer> e: 
retrievedState.immutableEntries()) {
+                               retrieved.put(e.getKey(), e.getValue());
+                       }
+                       assertTrue(expected.isEmpty()); // sanity check: expected was cleared above
+                       assertEquals(expected, retrieved);
+               } finally {
+                       operatorStateBackend.close();
+                       operatorStateBackend.dispose();
+                       if (stateHandle != null) {
+                               stateHandle.discardState();
+                       }
+               }
+       }
+
+       @Test
        public void testSnapshotRestoreSync() throws Exception {
-               AbstractStateBackend abstractStateBackend = new 
MemoryStateBackend(4096);
+               AbstractStateBackend abstractStateBackend = new 
MemoryStateBackend(2 * 4096);
 
                OperatorStateBackend operatorStateBackend = 
abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), 
"test-op-name");
                ListStateDescriptor<Serializable> stateDescriptor1 = new 
ListStateDescriptor<>("test1", new JavaSerializer<>());
                ListStateDescriptor<Serializable> stateDescriptor2 = new 
ListStateDescriptor<>("test2", new JavaSerializer<>());
                ListStateDescriptor<Serializable> stateDescriptor3 = new 
ListStateDescriptor<>("test3", new JavaSerializer<>());
+
+               MapStateDescriptor<Serializable, Serializable> 
broadcastStateDescriptor1 = new MapStateDescriptor<>("test4", new 
JavaSerializer<>(), new JavaSerializer<>());
+               MapStateDescriptor<Serializable, Serializable> 
broadcastStateDescriptor2 = new MapStateDescriptor<>("test5", new 
JavaSerializer<>(), new JavaSerializer<>());
+               MapStateDescriptor<Serializable, Serializable> 
broadcastStateDescriptor3 = new MapStateDescriptor<>("test6", new 
JavaSerializer<>(), new JavaSerializer<>());
+
                ListState<Serializable> listState1 = 
operatorStateBackend.getListState(stateDescriptor1);
                ListState<Serializable> listState2 = 
operatorStateBackend.getListState(stateDescriptor2);
                ListState<Serializable> listState3 = 
operatorStateBackend.getUnionListState(stateDescriptor3);
 
+               BroadcastState<Serializable, Serializable> broadcastState1 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor1);
+               BroadcastState<Serializable, Serializable> broadcastState2 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor2);
+               BroadcastState<Serializable, Serializable> broadcastState3 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor3);
+
                listState1.add(42);
                listState1.add(4711);
 
@@ -384,7 +491,12 @@ public class OperatorStateBackendTest {
                listState3.add(19);
                listState3.add(20);
 
-               CheckpointStreamFactory streamFactory = new 
MemCheckpointStreamFactory(4096);
+               broadcastState1.put(1, 2);
+               broadcastState1.put(2, 5);
+
+               broadcastState2.put(2, 5);
+
+               CheckpointStreamFactory streamFactory = new 
MemCheckpointStreamFactory(2 * 4096);
                RunnableFuture<OperatorStateHandle> runnableFuture =
                                operatorStateBackend.snapshot(1, 1, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
                OperatorStateHandle stateHandle = 
FutureUtil.runIfNotDoneAndGet(runnableFuture);
@@ -401,12 +513,18 @@ public class OperatorStateBackendTest {
                        
operatorStateBackend.restore(Collections.singletonList(stateHandle));
 
                        assertEquals(3, 
operatorStateBackend.getRegisteredStateNames().size());
+                       assertEquals(3, 
operatorStateBackend.getRegisteredBroadcastStateNames().size());
 
                        listState1 = 
operatorStateBackend.getListState(stateDescriptor1);
                        listState2 = 
operatorStateBackend.getListState(stateDescriptor2);
                        listState3 = 
operatorStateBackend.getUnionListState(stateDescriptor3);
 
+                       broadcastState1 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor1);
+                       broadcastState2 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor2);
+                       broadcastState3 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor3);
+
                        assertEquals(3, 
operatorStateBackend.getRegisteredStateNames().size());
+                       assertEquals(3, 
operatorStateBackend.getRegisteredBroadcastStateNames().size());
 
                        Iterator<Serializable> it = listState1.get().iterator();
                        assertEquals(42, it.next());
@@ -426,6 +544,27 @@ public class OperatorStateBackendTest {
                        assertEquals(20, it.next());
                        assertFalse(it.hasNext());
 
+                       Iterator<Map.Entry<Serializable, Serializable>> bIt = 
broadcastState1.iterator();
+                       assertTrue(bIt.hasNext());
+                       Map.Entry<Serializable, Serializable> entry = 
bIt.next();
+                       assertEquals(1, entry.getKey());
+                       assertEquals(2, entry.getValue());
+                       assertTrue(bIt.hasNext());
+                       entry = bIt.next();
+                       assertEquals(2, entry.getKey());
+                       assertEquals(5, entry.getValue());
+                       assertFalse(bIt.hasNext());
+
+                       bIt = broadcastState2.iterator();
+                       assertTrue(bIt.hasNext());
+                       entry = bIt.next();
+                       assertEquals(2, entry.getKey());
+                       assertEquals(5, entry.getValue());
+                       assertFalse(bIt.hasNext());
+
+                       bIt = broadcastState3.iterator();
+                       assertFalse(bIt.hasNext());
+
                        operatorStateBackend.close();
                        operatorStateBackend.dispose();
                } finally {
@@ -444,10 +583,22 @@ public class OperatorStateBackendTest {
                                new ListStateDescriptor<>("test2", new 
JavaSerializer<MutableType>());
                ListStateDescriptor<MutableType> stateDescriptor3 =
                                new ListStateDescriptor<>("test3", new 
JavaSerializer<MutableType>());
+
+               MapStateDescriptor<MutableType, MutableType> 
broadcastStateDescriptor1 =
+                               new MapStateDescriptor<>("test4", new 
JavaSerializer<MutableType>(), new JavaSerializer<MutableType>());
+               MapStateDescriptor<MutableType, MutableType> 
broadcastStateDescriptor2 =
+                               new MapStateDescriptor<>("test5", new 
JavaSerializer<MutableType>(), new JavaSerializer<MutableType>());
+               MapStateDescriptor<MutableType, MutableType> 
broadcastStateDescriptor3 =
+                               new MapStateDescriptor<>("test6", new 
JavaSerializer<MutableType>(), new JavaSerializer<MutableType>());
+
                ListState<MutableType> listState1 = 
operatorStateBackend.getListState(stateDescriptor1);
                ListState<MutableType> listState2 = 
operatorStateBackend.getListState(stateDescriptor2);
                ListState<MutableType> listState3 = 
operatorStateBackend.getUnionListState(stateDescriptor3);
 
+               BroadcastState<MutableType, MutableType> broadcastState1 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor1);
+               BroadcastState<MutableType, MutableType> broadcastState2 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor2);
+               BroadcastState<MutableType, MutableType> broadcastState3 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor3);
+
                listState1.add(MutableType.of(42));
                listState1.add(MutableType.of(4711));
 
@@ -460,6 +611,11 @@ public class OperatorStateBackendTest {
                listState3.add(MutableType.of(19));
                listState3.add(MutableType.of(20));
 
+               broadcastState1.put(MutableType.of(1), MutableType.of(2));
+               broadcastState1.put(MutableType.of(2), MutableType.of(5));
+
+               broadcastState2.put(MutableType.of(2), MutableType.of(5));
+
                BlockerCheckpointStreamFactory streamFactory = new 
BlockerCheckpointStreamFactory(1024 * 1024);
 
                OneShotLatch waiterLatch = new OneShotLatch();
@@ -482,6 +638,8 @@ public class OperatorStateBackendTest {
 
                listState1.add(MutableType.of(77));
 
+               broadcastState1.put(MutableType.of(32), MutableType.of(97));
+
                int n = 0;
 
                for (MutableType mutableType : listState2.get()) {
@@ -493,6 +651,7 @@ public class OperatorStateBackendTest {
                }
 
                listState3.clear();
+               broadcastState2.clear();
 
                operatorStateBackend.getListState(
                                new ListStateDescriptor<>("test4", new 
JavaSerializer<MutableType>()));
@@ -514,12 +673,18 @@ public class OperatorStateBackendTest {
                        
operatorStateBackend.restore(Collections.singletonList(stateHandle));
 
                        assertEquals(3, 
operatorStateBackend.getRegisteredStateNames().size());
+                       assertEquals(3, 
operatorStateBackend.getRegisteredBroadcastStateNames().size());
 
                        listState1 = 
operatorStateBackend.getListState(stateDescriptor1);
                        listState2 = 
operatorStateBackend.getListState(stateDescriptor2);
                        listState3 = 
operatorStateBackend.getUnionListState(stateDescriptor3);
 
+                       broadcastState1 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor1);
+                       broadcastState2 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor2);
+                       broadcastState3 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor3);
+
                        assertEquals(3, 
operatorStateBackend.getRegisteredStateNames().size());
+                       assertEquals(3, 
operatorStateBackend.getRegisteredBroadcastStateNames().size());
 
                        Iterator<MutableType> it = listState1.get().iterator();
                        assertEquals(42, it.next().value);
@@ -539,6 +704,27 @@ public class OperatorStateBackendTest {
                        assertEquals(20, it.next().value);
                        assertFalse(it.hasNext());
 
+                       Iterator<Map.Entry<MutableType, MutableType>> bIt = 
broadcastState1.iterator();
+                       assertTrue(bIt.hasNext());
+                       Map.Entry<MutableType, MutableType> entry = bIt.next();
+                       assertEquals(1, entry.getKey().value);
+                       assertEquals(2, entry.getValue().value);
+                       assertTrue(bIt.hasNext());
+                       entry = bIt.next();
+                       assertEquals(2, entry.getKey().value);
+                       assertEquals(5, entry.getValue().value);
+                       assertFalse(bIt.hasNext());
+
+                       bIt = broadcastState2.iterator();
+                       assertTrue(bIt.hasNext());
+                       entry = bIt.next();
+                       assertEquals(2, entry.getKey().value);
+                       assertEquals(5, entry.getValue().value);
+                       assertFalse(bIt.hasNext());
+
+                       bIt = broadcastState3.iterator();
+                       assertFalse(bIt.hasNext());
+
                        operatorStateBackend.close();
                        operatorStateBackend.dispose();
                } finally {
@@ -558,10 +744,16 @@ public class OperatorStateBackendTest {
 
                ListState<MutableType> listState1 = 
operatorStateBackend.getOperatorState(stateDescriptor1);
 
-
                listState1.add(MutableType.of(42));
                listState1.add(MutableType.of(4711));
 
+               MapStateDescriptor<MutableType, MutableType> 
broadcastStateDescriptor1 =
+                               new MapStateDescriptor<>("test4", new 
JavaSerializer<MutableType>(), new JavaSerializer<MutableType>());
+
+               BroadcastState<MutableType, MutableType> broadcastState1 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor1);
+               broadcastState1.put(MutableType.of(1), MutableType.of(2));
+               broadcastState1.put(MutableType.of(2), MutableType.of(5));
+
                BlockerCheckpointStreamFactory streamFactory = new 
BlockerCheckpointStreamFactory(1024 * 1024);
 
                OneShotLatch waiterLatch = new OneShotLatch();
@@ -602,7 +794,6 @@ public class OperatorStateBackendTest {
 
                ListState<MutableType> listState1 = 
operatorStateBackend.getOperatorState(stateDescriptor1);
 
-
                listState1.add(MutableType.of(42));
                listState1.add(MutableType.of(4711));
 

http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
index ab801b6..88f9cd7 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
@@ -28,10 +28,11 @@ public class OperatorStateHandleTest {
 
                // Ensure the order / ordinal of all values of enum 'mode' are 
fixed, as this is used for serialization
                Assert.assertEquals(0, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.ordinal());
-               Assert.assertEquals(1, 
OperatorStateHandle.Mode.BROADCAST.ordinal());
+               Assert.assertEquals(1, 
OperatorStateHandle.Mode.UNION.ordinal());
+               Assert.assertEquals(2, 
OperatorStateHandle.Mode.BROADCAST.ordinal());
 
                // Ensure all enum values are registered and fixed forever by 
this test
-               Assert.assertEquals(2, 
OperatorStateHandle.Mode.values().length);
+               Assert.assertEquals(3, 
OperatorStateHandle.Mode.values().length);
 
                // Byte is used to encode enum value on serialization
                Assert.assertTrue(OperatorStateHandle.Mode.values().length <= 
Byte.MAX_VALUE);

http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
index 341d4fe..57e4aed 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
@@ -24,6 +24,7 @@ import 
org.apache.flink.api.common.typeutils.TypeSerializerSerializationUtil;
 import org.apache.flink.api.common.typeutils.base.DoubleSerializer;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
 import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -205,6 +206,8 @@ public class SerializationProxiesTest {
        public void testOperatorBackendSerializationProxyRoundtrip() throws 
Exception {
 
                TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
+               TypeSerializer<?> keySerializer = DoubleSerializer.INSTANCE;
+               TypeSerializer<?> valueSerializer = StringSerializer.INSTANCE;
 
                List<RegisteredOperatorBackendStateMetaInfo.Snapshot<?>> 
stateMetaInfoSnapshots = new ArrayList<>();
 
@@ -213,10 +216,17 @@ public class SerializationProxiesTest {
                stateMetaInfoSnapshots.add(new 
RegisteredOperatorBackendStateMetaInfo<>(
                        "b", stateSerializer, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE).snapshot());
                stateMetaInfoSnapshots.add(new 
RegisteredOperatorBackendStateMetaInfo<>(
-                       "c", stateSerializer, 
OperatorStateHandle.Mode.BROADCAST).snapshot());
+                       "c", stateSerializer, 
OperatorStateHandle.Mode.UNION).snapshot());
+
+               List<RegisteredBroadcastBackendStateMetaInfo.Snapshot<?, ?>> 
broadcastStateMetaInfoSnapshots = new ArrayList<>();
+
+               broadcastStateMetaInfoSnapshots.add(new 
RegisteredBroadcastBackendStateMetaInfo<>(
+                               "d", OperatorStateHandle.Mode.BROADCAST, 
keySerializer, valueSerializer).snapshot());
+               broadcastStateMetaInfoSnapshots.add(new 
RegisteredBroadcastBackendStateMetaInfo<>(
+                               "e", OperatorStateHandle.Mode.BROADCAST, 
valueSerializer, keySerializer).snapshot());
 
                OperatorBackendSerializationProxy serializationProxy =
-                               new 
OperatorBackendSerializationProxy(stateMetaInfoSnapshots);
+                               new 
OperatorBackendSerializationProxy(stateMetaInfoSnapshots, 
broadcastStateMetaInfoSnapshots);
 
                byte[] serialized;
                try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
@@ -231,7 +241,8 @@ public class SerializationProxiesTest {
                        serializationProxy.read(new 
DataInputViewStreamWrapper(in));
                }
 
-               Assert.assertEquals(stateMetaInfoSnapshots, 
serializationProxy.getStateMetaInfoSnapshots());
+               Assert.assertEquals(stateMetaInfoSnapshots, 
serializationProxy.getOperatorStateMetaInfoSnapshots());
+               Assert.assertEquals(broadcastStateMetaInfoSnapshots, 
serializationProxy.getBroadcastStateMetaInfoSnapshots());
        }
 
        @Test
@@ -242,24 +253,58 @@ public class SerializationProxiesTest {
 
                RegisteredOperatorBackendStateMetaInfo.Snapshot<?> metaInfo =
                        new RegisteredOperatorBackendStateMetaInfo<>(
-                               name, stateSerializer, 
OperatorStateHandle.Mode.BROADCAST).snapshot();
+                               name, stateSerializer, 
OperatorStateHandle.Mode.UNION).snapshot();
 
                byte[] serialized;
                try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
                        OperatorBackendStateMetaInfoSnapshotReaderWriters
-                               
.getWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo)
-                               .writeStateMetaInfo(new 
DataOutputViewStreamWrapper(out));
+                               
.getOperatorStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, 
metaInfo)
+                               .writeOperatorStateMetaInfo(new 
DataOutputViewStreamWrapper(out));
 
                        serialized = out.toByteArray();
                }
 
                try (ByteArrayInputStreamWithPos in = new 
ByteArrayInputStreamWithPos(serialized)) {
                        metaInfo = 
OperatorBackendStateMetaInfoSnapshotReaderWriters
-                               
.getReaderForVersion(OperatorBackendSerializationProxy.VERSION, 
Thread.currentThread().getContextClassLoader())
-                               .readStateMetaInfo(new 
DataInputViewStreamWrapper(in));
+                               
.getOperatorStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, 
Thread.currentThread().getContextClassLoader())
+                               .readOperatorStateMetaInfo(new 
DataInputViewStreamWrapper(in));
+               }
+
+               Assert.assertEquals(name, metaInfo.getName());
+               Assert.assertEquals(OperatorStateHandle.Mode.UNION, 
metaInfo.getAssignmentMode());
+               Assert.assertEquals(stateSerializer, 
metaInfo.getPartitionStateSerializer());
+       }
+
+       @Test
+       public void testBroadcastStateMetaInfoSerialization() throws Exception {
+
+               String name = "test";
+               TypeSerializer<?> keySerializer = DoubleSerializer.INSTANCE;
+               TypeSerializer<?> valueSerializer = StringSerializer.INSTANCE;
+
+               RegisteredBroadcastBackendStateMetaInfo.Snapshot<?, ?> metaInfo 
=
+                               new RegisteredBroadcastBackendStateMetaInfo<>(
+                                               name, 
OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot();
+
+               byte[] serialized;
+               try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
+                       OperatorBackendStateMetaInfoSnapshotReaderWriters
+                                       
.getBroadcastStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, 
metaInfo)
+                                       .writeBroadcastStateMetaInfo(new 
DataOutputViewStreamWrapper(out));
+
+                       serialized = out.toByteArray();
+               }
+
+               try (ByteArrayInputStreamWithPos in = new 
ByteArrayInputStreamWithPos(serialized)) {
+                       metaInfo = 
OperatorBackendStateMetaInfoSnapshotReaderWriters
+                                       
.getBroadcastStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, 
Thread.currentThread().getContextClassLoader())
+                                       .readBroadcastStateMetaInfo(new 
DataInputViewStreamWrapper(in));
                }
 
                Assert.assertEquals(name, metaInfo.getName());
+               Assert.assertEquals(OperatorStateHandle.Mode.BROADCAST, 
metaInfo.getAssignmentMode());
+               Assert.assertEquals(keySerializer, metaInfo.getKeySerializer());
+               Assert.assertEquals(valueSerializer, 
metaInfo.getValueSerializer());
        }
 
        @Test
@@ -269,13 +314,13 @@ public class SerializationProxiesTest {
 
                RegisteredOperatorBackendStateMetaInfo.Snapshot<?> metaInfo =
                        new RegisteredOperatorBackendStateMetaInfo<>(
-                               name, stateSerializer, 
OperatorStateHandle.Mode.BROADCAST).snapshot();
+                               name, stateSerializer, 
OperatorStateHandle.Mode.UNION).snapshot();
 
                byte[] serialized;
                try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
                        OperatorBackendStateMetaInfoSnapshotReaderWriters
-                               
.getWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo)
-                               .writeStateMetaInfo(new 
DataOutputViewStreamWrapper(out));
+                               
.getOperatorStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, 
metaInfo)
+                               .writeOperatorStateMetaInfo(new 
DataOutputViewStreamWrapper(out));
 
                        serialized = out.toByteArray();
                }
@@ -288,8 +333,8 @@ public class SerializationProxiesTest {
 
                try (ByteArrayInputStreamWithPos in = new 
ByteArrayInputStreamWithPos(serialized)) {
                        metaInfo = 
OperatorBackendStateMetaInfoSnapshotReaderWriters
-                               
.getReaderForVersion(OperatorBackendSerializationProxy.VERSION, 
Thread.currentThread().getContextClassLoader())
-                               .readStateMetaInfo(new 
DataInputViewStreamWrapper(in));
+                               
.getOperatorStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, 
Thread.currentThread().getContextClassLoader())
+                               .readOperatorStateMetaInfo(new 
DataInputViewStreamWrapper(in));
                }
 
                Assert.assertEquals(name, metaInfo.getName());
@@ -297,6 +342,44 @@ public class SerializationProxiesTest {
                Assert.assertEquals(stateSerializer.snapshotConfiguration(), 
metaInfo.getPartitionStateSerializerConfigSnapshot());
        }
 
+       @Test
+       public void testBroadcastStateMetaInfoReadSerializerFailureResilience() 
throws Exception {
+               String broadcastName = "broadcastTest";
+               TypeSerializer<?> keySerializer = DoubleSerializer.INSTANCE;
+               TypeSerializer<?> valueSerializer = StringSerializer.INSTANCE;
+
+               RegisteredBroadcastBackendStateMetaInfo.Snapshot<?, ?> 
broadcastMetaInfo =
+                               new RegisteredBroadcastBackendStateMetaInfo<>(
+                                               broadcastName, 
OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot();
+
+               byte[] serialized;
+               try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
+                       OperatorBackendStateMetaInfoSnapshotReaderWriters
+                                       
.getBroadcastStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, 
broadcastMetaInfo)
+                                       .writeBroadcastStateMetaInfo(new 
DataOutputViewStreamWrapper(out));
+
+                       serialized = out.toByteArray();
+               }
+
+               // mock failure when deserializing serializer
+               
TypeSerializerSerializationUtil.TypeSerializerSerializationProxy<?> mockProxy =
+                               
mock(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class);
+               doThrow(new 
IOException()).when(mockProxy).read(any(DataInputViewStreamWrapper.class));
+               
PowerMockito.whenNew(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class).withAnyArguments().thenReturn(mockProxy);
+
+               try (ByteArrayInputStreamWithPos in = new 
ByteArrayInputStreamWithPos(serialized)) {
+                       broadcastMetaInfo = 
OperatorBackendStateMetaInfoSnapshotReaderWriters
+                                       
.getBroadcastStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, 
Thread.currentThread().getContextClassLoader())
+                                       .readBroadcastStateMetaInfo(new 
DataInputViewStreamWrapper(in));
+               }
+
+               Assert.assertEquals(broadcastName, broadcastMetaInfo.getName());
+               Assert.assertEquals(null, broadcastMetaInfo.getKeySerializer());
+               Assert.assertEquals(keySerializer.snapshotConfiguration(), 
broadcastMetaInfo.getKeySerializerConfigSnapshot());
+               Assert.assertEquals(null, 
broadcastMetaInfo.getValueSerializer());
+               Assert.assertEquals(valueSerializer.snapshotConfiguration(), 
broadcastMetaInfo.getValueSerializerConfigSnapshot());
+       }
+
        /**
         * This test fixes the order of elements in the enum which is important 
for serialization. Do not modify this test
         * except if you are entirely sure what you are doing.

Reply via email to