http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 9adaa86..c39e436 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -20,7 +20,9 @@ package org.apache.flink.runtime.checkpoint;
 
 import com.google.common.collect.Iterables;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.checkpoint.savepoint.HeapSavepointStore;
 import 
org.apache.flink.runtime.checkpoint.stats.DisabledCheckpointStatsTracker;
 import org.apache.flink.runtime.execution.ExecutionState;
@@ -34,21 +36,21 @@ import 
org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
-
 import org.junit.Assert;
 import org.junit.Test;
-
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
-
 import scala.concurrent.ExecutionContext;
 import scala.concurrent.Future;
 
@@ -56,6 +58,8 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -1459,7 +1463,7 @@ public class CheckpointCoordinatorTest {
                                        maxConcurrentAttempts,
                                        new ExecutionVertex[] { triggerVertex },
                                        new ExecutionVertex[] { ackVertex },
-                                       new ExecutionVertex[] { commitVertex }, 
+                                       new ExecutionVertex[] { commitVertex },
                                        new StandaloneCheckpointIDCounter(),
                                        new 
StandaloneCompletedCheckpointStore(2, cl),
                                        new HeapSavepointStore(),
@@ -1531,7 +1535,7 @@ public class CheckpointCoordinatorTest {
                                        maxConcurrentAttempts, // max two 
concurrent checkpoints
                                        new ExecutionVertex[] { triggerVertex },
                                        new ExecutionVertex[] { ackVertex },
-                                       new ExecutionVertex[] { commitVertex }, 
+                                       new ExecutionVertex[] { commitVertex },
                                        new StandaloneCheckpointIDCounter(),
                                        new 
StandaloneCompletedCheckpointStore(2, cl),
                                        new HeapSavepointStore(),
@@ -1791,29 +1795,29 @@ public class CheckpointCoordinatorTest {
 
                for (int index = 0; index < jobVertex1.getParallelism(); 
index++) {
                        ChainedStateHandle<StreamStateHandle> 
nonPartitionedState = generateStateForVertex(jobVertexID1, index);
+                       ChainedStateHandle<OperatorStateHandle> 
partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, 
index, 2, 8);
                        List<KeyGroupsStateHandle> partitionedKeyGroupState = 
generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
 
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(nonPartitionedState, partitionableState, 
partitionedKeyGroupState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
-                               jid,
-                               
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-                               checkpointId,
-                               nonPartitionedState,
-                               partitionedKeyGroupState);
+                                       jid,
+                                       
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+                                       checkpointId,
+                                       checkpointStateHandles);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
 
-
                for (int index = 0; index < jobVertex2.getParallelism(); 
index++) {
                        ChainedStateHandle<StreamStateHandle> 
nonPartitionedState = generateStateForVertex(jobVertexID2, index);
+                       ChainedStateHandle<OperatorStateHandle> 
partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, 
index, 2, 8);
                        List<KeyGroupsStateHandle> partitionedKeyGroupState = 
generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(nonPartitionedState, partitionableState, 
partitionedKeyGroupState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
-                               jid,
-                               
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-                               checkpointId,
-                               nonPartitionedState,
-                               partitionedKeyGroupState);
+                                       jid,
+                                       
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+                                       checkpointId,
+                                       checkpointStateHandles);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -1895,13 +1899,12 @@ public class CheckpointCoordinatorTest {
                for (int index = 0; index < jobVertex1.getParallelism(); 
index++) {
                        ChainedStateHandle<StreamStateHandle> valueSizeTuple = 
generateStateForVertex(jobVertexID1, index);
                        List<KeyGroupsStateHandle> keyGroupState = 
generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
-
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
-                               jid,
-                               
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-                               checkpointId,
-                               valueSizeTuple,
-                               keyGroupState);
+                                       jid,
+                                       
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+                                       checkpointId,
+                                       checkpointStateHandles);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -1910,13 +1913,12 @@ public class CheckpointCoordinatorTest {
                for (int index = 0; index < jobVertex2.getParallelism(); 
index++) {
                        ChainedStateHandle<StreamStateHandle> valueSizeTuple = 
generateStateForVertex(jobVertexID2, index);
                        List<KeyGroupsStateHandle> keyGroupState = 
generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
-                               jid,
-                               
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-                               checkpointId,
-                               valueSizeTuple,
-                               keyGroupState);
+                                       jid,
+                                       
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+                                       checkpointId,
+                                       checkpointStateHandles);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2014,12 +2016,12 @@ public class CheckpointCoordinatorTest {
                        List<KeyGroupsStateHandle> keyGroupState = 
generateKeyGroupState(
                                        jobVertexID1, 
keyGroupPartitions1.get(index));
 
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
-                               jid,
-                               
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-                               checkpointId,
-                               valueSizeTuple,
-                               keyGroupState);
+                                       jid,
+                                       
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+                                       checkpointId,
+                                       checkpointStateHandles);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2031,12 +2033,12 @@ public class CheckpointCoordinatorTest {
                        List<KeyGroupsStateHandle> keyGroupState = 
generateKeyGroupState(
                                        jobVertexID2, 
keyGroupPartitions2.get(index));
 
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(state, null, keyGroupState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
-                                       state,
-                                       keyGroupState);
+                                       checkpointStateHandles);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2132,28 +2134,32 @@ public class CheckpointCoordinatorTest {
 
                for (int index = 0; index < jobVertex1.getParallelism(); 
index++) {
                        ChainedStateHandle<StreamStateHandle> valueSizeTuple = 
generateStateForVertex(jobVertexID1, index);
+                       ChainedStateHandle<OperatorStateHandle> 
partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, 
index, 2, 8);
                        List<KeyGroupsStateHandle> keyGroupState = 
generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
 
+
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(valueSizeTuple, partitionableState, keyGroupState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
-                                       valueSizeTuple,
-                                       keyGroupState);
+                                       checkpointStateHandles);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
 
 
+               final List<ChainedStateHandle<OperatorStateHandle>> 
originalPartitionableStates = new ArrayList<>(jobVertex2.getParallelism());
                for (int index = 0; index < jobVertex2.getParallelism(); 
index++) {
                        List<KeyGroupsStateHandle> keyGroupState = 
generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+                       ChainedStateHandle<OperatorStateHandle> 
partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, 
index, 2, 8);
+                       originalPartitionableStates.add(partitionableState);
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(null, partitionableState, keyGroupState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
-                                       null,
-                                       keyGroupState);
+                                       checkpointStateHandles);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2185,22 +2191,49 @@ public class CheckpointCoordinatorTest {
 
                // verify the restored state
                verifiyStateRestore(jobVertexID1, newJobVertex1, 
keyGroupPartitions1);
-
+               List<List<Collection<OperatorStateHandle>>> 
actualPartitionableStates = new ArrayList<>(newJobVertex2.getParallelism());
                for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
                        List<KeyGroupsStateHandle> originalKeyGroupState = 
generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i));
 
                        ChainedStateHandle<StreamStateHandle> operatorState = 
newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
+                       List<Collection<OperatorStateHandle>> 
partitionableState = 
newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
                        List<KeyGroupsStateHandle> keyGroupState = 
newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
 
+                       actualPartitionableStates.add(partitionableState);
                        assertNull(operatorState);
-                       comparePartitionedState(originalKeyGroupState, 
keyGroupState);
+                       compareKeyPartitionedState(originalKeyGroupState, 
keyGroupState);
                }
+               comparePartitionableState(originalPartitionableStates, 
actualPartitionableStates);
        }
 
        // 
------------------------------------------------------------------------
        //  Utilities
        // 
------------------------------------------------------------------------
 
+       static void sendAckMessageToCoordinator(
+                       CheckpointCoordinator coord,
+                       long checkpointId, JobID jid,
+                       ExecutionJobVertex jobVertex,
+                       JobVertexID jobVertexID,
+                       List<KeyGroupRange> keyGroupPartitions) throws 
Exception {
+
+               for (int index = 0; index < jobVertex.getParallelism(); 
index++) {
+                       ChainedStateHandle<StreamStateHandle> state = 
generateStateForVertex(jobVertexID, index);
+                       List<KeyGroupsStateHandle> keyGroupState = 
generateKeyGroupState(
+                                       jobVertexID,
+                                       keyGroupPartitions.get(index));
+
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(state, null, keyGroupState);
+                       AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
+                                       jid,
+                                       
jobVertex.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+                                       checkpointId,
+                                       checkpointStateHandles);
+
+                       coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+               }
+       }
+
        public static List<KeyGroupsStateHandle> generateKeyGroupState(
                        JobVertexID jobVertexID,
                        KeyGroupRange keyGroupPartition) throws IOException {
@@ -2217,23 +2250,45 @@ public class CheckpointCoordinatorTest {
                return generateKeyGroupState(keyGroupPartition, 
testStatesLists);
        }
 
-       public static List<KeyGroupsStateHandle> 
generateKeyGroupState(KeyGroupRange keyGroupRange, List< ? extends 
Serializable> states) throws IOException {
+       public static List<KeyGroupsStateHandle> generateKeyGroupState(
+                       KeyGroupRange keyGroupRange,
+                       List<? extends Serializable> states) throws IOException 
{
+
                
Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == 
states.size());
 
-               long[] offsets = new long[keyGroupRange.getNumberOfKeyGroups()];
-               List<byte[]> serializedGroupValues = new 
ArrayList<>(offsets.length);
+               Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
+                               
serializeTogetherAndTrackOffsets(Collections.<List<? extends 
Serializable>>singletonList(states));
+
+               KeyGroupRangeOffsets keyGroupRangeOffsets = new 
KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
+
+               ByteStreamStateHandle allSerializedStatesHandle = new 
ByteStreamStateHandle(
+                               serializedDataWithOffsets.f0);
+               KeyGroupsStateHandle keyGroupsStateHandle = new 
KeyGroupsStateHandle(
+                               keyGroupRangeOffsets,
+                               allSerializedStatesHandle);
+               List<KeyGroupsStateHandle> keyGroupsStateHandleList = new 
ArrayList<>();
+               keyGroupsStateHandleList.add(keyGroupsStateHandle);
+               return keyGroupsStateHandleList;
+       }
+
+       public static Tuple2<byte[], List<long[]>> 
serializeTogetherAndTrackOffsets(
+                       List<List<? extends Serializable>> serializables) 
throws IOException {
 
-               KeyGroupRangeOffsets keyGroupRangeOffsets = new 
KeyGroupRangeOffsets(keyGroupRange, offsets);
+               List<long[]> offsets = new ArrayList<>(serializables.size());
+               List<byte[]> serializedGroupValues = new ArrayList<>();
 
                int runningGroupsOffset = 0;
-               // generate test state for all keygroups
-               int idx = 0;
-               for (int keyGroup : keyGroupRange) {
-                       
keyGroupRangeOffsets.setKeyGroupOffset(keyGroup,runningGroupsOffset);
-                       byte[] serializedValue = 
InstantiationUtil.serializeObject(states.get(idx));
-                       runningGroupsOffset += serializedValue.length;
-                       serializedGroupValues.add(serializedValue);
-                       ++idx;
+               for(List<? extends Serializable> list : serializables) {
+
+                       long[] currentOffsets = new long[list.size()];
+                       offsets.add(currentOffsets);
+
+                       for (int i = 0; i < list.size(); ++i) {
+                               currentOffsets[i] = runningGroupsOffset;
+                               byte[] serializedValue = 
InstantiationUtil.serializeObject(list.get(i));
+                               serializedGroupValues.add(serializedValue);
+                               runningGroupsOffset += serializedValue.length;
+                       }
                }
 
                //write all generated values in a single byte array, which is 
index by groupOffsetsInFinalByteArray
@@ -2248,15 +2303,7 @@ public class CheckpointCoordinatorTest {
                                        serializedGroupValue.length);
                        runningGroupsOffset += serializedGroupValue.length;
                }
-
-               ByteStreamStateHandle allSerializedStatesHandle = new 
ByteStreamStateHandle(
-                               allSerializedValuesConcatenated);
-               KeyGroupsStateHandle keyGroupsStateHandle = new 
KeyGroupsStateHandle(
-                               keyGroupRangeOffsets,
-                               allSerializedStatesHandle);
-               List<KeyGroupsStateHandle> keyGroupsStateHandleList = new 
ArrayList<>();
-               keyGroupsStateHandleList.add(keyGroupsStateHandle);
-               return keyGroupsStateHandleList;
+               return new Tuple2<>(allSerializedValuesConcatenated, offsets);
        }
 
        public static ChainedStateHandle<StreamStateHandle> 
generateStateForVertex(
@@ -2273,6 +2320,55 @@ public class CheckpointCoordinatorTest {
                return 
ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value));
        }
 
+       public static ChainedStateHandle<OperatorStateHandle> 
generateChainedPartitionableStateHandle(
+                       JobVertexID jobVertexID,
+                       int index,
+                       int namedStates,
+                       int partitionsPerState) throws IOException {
+
+               Map<String, List<? extends Serializable>> statesListsMap = new 
HashMap<>(namedStates);
+
+               for (int i = 0; i < namedStates; ++i) {
+                       List<Integer> testStatesLists = new 
ArrayList<>(partitionsPerState);
+                       // generate state
+                       Random random = new Random(jobVertexID.hashCode() * 
index + i * namedStates);
+                       for (int j = 0; j < partitionsPerState; ++j) {
+                               int simulatedStateValue = random.nextInt();
+                               testStatesLists.add(simulatedStateValue);
+                       }
+                       statesListsMap.put("state-" + i, testStatesLists);
+               }
+
+               return generateChainedPartitionableStateHandle(statesListsMap);
+       }
+
+       public static ChainedStateHandle<OperatorStateHandle> 
generateChainedPartitionableStateHandle(
+                       Map<String, List<? extends Serializable>> states) 
throws IOException {
+
+               List<List<? extends Serializable>> namedStateSerializables = 
new ArrayList<>(states.size());
+
+               for (Map.Entry<String, List<? extends Serializable>> entry : 
states.entrySet()) {
+                       namedStateSerializables.add(entry.getValue());
+               }
+
+               Tuple2<byte[], List<long[]>> serializationWithOffsets = 
serializeTogetherAndTrackOffsets(namedStateSerializables);
+
+               Map<String, long[]> offsetsMap = new HashMap<>(states.size());
+
+               int idx = 0;
+               for (Map.Entry<String, List<? extends Serializable>> entry : 
states.entrySet()) {
+                       offsetsMap.put(entry.getKey(), 
serializationWithOffsets.f1.get(idx));
+                       ++idx;
+               }
+
+               ByteStreamStateHandle streamStateHandle = new 
ByteStreamStateHandle(
+                               serializationWithOffsets.f0);
+
+               OperatorStateHandle operatorStateHandle =
+                               new OperatorStateHandle(streamStateHandle, 
offsetsMap);
+               return ChainedStateHandle.wrapSingleHandle(operatorStateHandle);
+       }
+
        public static ExecutionJobVertex mockExecutionJobVertex(
                JobVertexID jobVertexID,
                int parallelism,
@@ -2348,16 +2444,24 @@ public class CheckpointCoordinatorTest {
                                        
getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
                        assertEquals(expectNonPartitionedState.get(0), 
actualNonPartitionedState.get(0));
 
+                       ChainedStateHandle<OperatorStateHandle> 
expectedPartitionableState =
+                                       
generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8);
+
+                       List<Collection<OperatorStateHandle>> 
actualPartitionableState = executionJobVertex.
+                                       
getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
+
+                       assertEquals(expectedPartitionableState.get(0), 
actualPartitionableState.get(0).iterator().next());
+
                        List<KeyGroupsStateHandle> 
expectPartitionedKeyGroupState = generateKeyGroupState(
                                        jobVertexID,
                                        keyGroupPartitions.get(i));
                        List<KeyGroupsStateHandle> 
actualPartitionedKeyGroupState = executionJobVertex.
                                        
getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
-                       comparePartitionedState(expectPartitionedKeyGroupState, 
actualPartitionedKeyGroupState);
+                       
compareKeyPartitionedState(expectPartitionedKeyGroupState, 
actualPartitionedKeyGroupState);
                }
        }
 
-       public static void comparePartitionedState(
+       public static void compareKeyPartitionedState(
                        List<KeyGroupsStateHandle> 
expectPartitionedKeyGroupState,
                        List<KeyGroupsStateHandle> 
actualPartitionedKeyGroupState) throws Exception {
 
@@ -2370,22 +2474,68 @@ public class CheckpointCoordinatorTest {
 
                assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
 
-               FSDataInputStream inputStream = 
expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream();
-               for(int groupId : 
expectedHeadOpKeyGroupStateHandle.keyGroups()) {
-                       long offset = 
expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-                       inputStream.seek(offset);
-                       int expectedKeyGroupState = 
InstantiationUtil.deserializeObject(inputStream);
-                       for(KeyGroupsStateHandle oneActualKeyGroupStateHandle : 
actualPartitionedKeyGroupState) {
-                               if 
(oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
-                                       long actualOffset = 
oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-                                       FSDataInputStream actualInputStream = 
oneActualKeyGroupStateHandle.getStateHandle().openInputStream();
-                                       actualInputStream.seek(actualOffset);
-                                       int actualGroupState = 
InstantiationUtil.deserializeObject(actualInputStream);
-
-                                       assertEquals(expectedKeyGroupState, 
actualGroupState);
+               try (FSDataInputStream inputStream = 
expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream()) {
+                       for (int groupId : 
expectedHeadOpKeyGroupStateHandle.keyGroups()) {
+                               long offset = 
expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+                               inputStream.seek(offset);
+                               int expectedKeyGroupState = 
InstantiationUtil.deserializeObject(inputStream);
+                               for (KeyGroupsStateHandle 
oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
+                                       if 
(oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
+                                               long actualOffset = 
oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+                                               try (FSDataInputStream 
actualInputStream =
+                                                                    
oneActualKeyGroupStateHandle.getStateHandle().openInputStream()) {
+                                                       
actualInputStream.seek(actualOffset);
+                                                       int actualGroupState = 
InstantiationUtil.deserializeObject(actualInputStream);
+                                                       
assertEquals(expectedKeyGroupState, actualGroupState);
+                                               }
+                                       }
+                               }
+                       }
+               }
+       }
+
+       public static void comparePartitionableState(
+                       List<ChainedStateHandle<OperatorStateHandle>> expected,
+                       List<List<Collection<OperatorStateHandle>>> actual) 
throws Exception {
+
+               List<String> expectedResult = new ArrayList<>();
+               for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle 
: expected) {
+                       for (int i = 0; i < chainedStateHandle.getLength(); 
++i) {
+                               OperatorStateHandle operatorStateHandle = 
chainedStateHandle.get(i);
+                               try (FSDataInputStream in = 
operatorStateHandle.openInputStream()) {
+                                       for (Map.Entry<String, long[]> entry : 
operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+                                               for (long offset : 
entry.getValue()) {
+                                                       in.seek(offset);
+                                                       Integer state = 
InstantiationUtil.deserializeObject(in);
+                                                       expectedResult.add(i + 
" : " + entry.getKey() + " : " + state);
+                                               }
+                                       }
                                }
                        }
                }
+               Collections.sort(expectedResult);
+
+               List<String> actualResult = new ArrayList<>();
+               for (List<Collection<OperatorStateHandle>> collectionList : 
actual) {
+                       if (collectionList != null) {
+                               for (int i = 0; i < collectionList.size(); ++i) 
{
+                                       Collection<OperatorStateHandle> 
stateHandles = collectionList.get(i);
+                                       for (OperatorStateHandle 
operatorStateHandle : stateHandles) {
+                                               try (FSDataInputStream in = 
operatorStateHandle.openInputStream()) {
+                                                       for (Map.Entry<String, 
long[]> entry : 
operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+                                                               for (long 
offset : entry.getValue()) {
+                                                                       
in.seek(offset);
+                                                                       Integer 
state = InstantiationUtil.deserializeObject(in);
+                                                                       
actualResult.add(i + " : " + entry.getKey() + " : " + state);
+                                                               }
+                                                       }
+                                               }
+                                       }
+                               }
+                       }
+               }
+               Collections.sort(actualResult);
+               Assert.assertEquals(expectedResult, actualResult);
        }
 
        @Test
@@ -2415,4 +2565,117 @@ public class CheckpointCoordinatorTest {
                }
        }
 
+
+       @Test
+       public void testPartitionableStateRepartitioning() {
+               Random r = new Random(42);
+
+               for (int run = 0; run < 10000; ++run) {
+                       int oldParallelism = 1 + r.nextInt(9);
+                       int newParallelism = 1 + r.nextInt(9);
+
+                       int numNamedStates = 1 + r.nextInt(9);
+                       int maxPartitionsPerState = 1 + r.nextInt(9);
+
+                       doTestPartitionableStateRepartitioning(
+                                       r, oldParallelism, newParallelism, 
numNamedStates, maxPartitionsPerState);
+               }
+       }
+
+       private void doTestPartitionableStateRepartitioning(
+                       Random r, int oldParallelism, int newParallelism, int 
numNamedStates, int maxPartitionsPerState) {
+
+               List<OperatorStateHandle> previousParallelOpInstanceStates = 
new ArrayList<>(oldParallelism);
+
+               for (int i = 0; i < oldParallelism; ++i) {
+                       Path fakePath = new Path("/fake-" + i);
+                       Map<String, long[]> namedStatesToOffsets = new 
HashMap<>();
+                       int off = 0;
+                       for (int s = 0; s < numNamedStates; ++s) {
+                               long[] offs = new long[1 + 
r.nextInt(maxPartitionsPerState)];
+                               if (offs.length > 0) {
+                                       for (int o = 0; o < offs.length; ++o) {
+                                               offs[o] = off;
+                                               ++off;
+                                       }
+                                       namedStatesToOffsets.put("State-" + s, 
offs);
+                               }
+                       }
+
+                       previousParallelOpInstanceStates.add(
+                                       new OperatorStateHandle(new 
FileStateHandle(fakePath, -1), namedStatesToOffsets));
+               }
+
+               Map<StreamStateHandle, Map<String, List<Long>>> expected = new 
HashMap<>();
+
+               int expectedTotalPartitions = 0;
+               for (OperatorStateHandle psh : 
previousParallelOpInstanceStates) {
+                       Map<String, long[]> offsMap = 
psh.getStateNameToPartitionOffsets();
+                       Map<String, List<Long>> offsMapWithList = new 
HashMap<>(offsMap.size());
+                       for (Map.Entry<String, long[]> e : offsMap.entrySet()) {
+                               long[] offs = e.getValue();
+                               expectedTotalPartitions += offs.length;
+                               List<Long> offsList = new 
ArrayList<>(offs.length);
+                               for (int i = 0; i < offs.length; ++i) {
+                                       offsList.add(i, offs[i]);
+                               }
+                               offsMapWithList.put(e.getKey(), offsList);
+                       }
+                       expected.put(psh.getDelegateStateHandle(), 
offsMapWithList);
+               }
+
+               OperatorStateRepartitioner repartitioner = 
RoundRobinOperatorStateRepartitioner.INSTANCE;
+
+               List<Collection<OperatorStateHandle>> pshs =
+                               
repartitioner.repartitionState(previousParallelOpInstanceStates, 
newParallelism);
+
+               Map<StreamStateHandle, Map<String, List<Long>>> actual = new 
HashMap<>();
+
+               int minCount = Integer.MAX_VALUE;
+               int maxCount = 0;
+               int actualTotalPartitions = 0;
+               for (int p = 0; p < newParallelism; ++p) {
+                       int partitionCount = 0;
+
+                       Collection<OperatorStateHandle> pshc = pshs.get(p);
+                       for (OperatorStateHandle sh : pshc) {
+                               for (Map.Entry<String, long[]> namedState : 
sh.getStateNameToPartitionOffsets().entrySet()) {
+
+                                       Map<String, List<Long>> x = 
actual.get(sh.getDelegateStateHandle());
+                                       if (x == null) {
+                                               x = new HashMap<>();
+                                               
actual.put(sh.getDelegateStateHandle(), x);
+                                       }
+
+                                       List<Long> actualOffs = 
x.get(namedState.getKey());
+                                       if (actualOffs == null) {
+                                               actualOffs = new ArrayList<>();
+                                               x.put(namedState.getKey(), 
actualOffs);
+                                       }
+                                       long[] add = namedState.getValue();
+                                       for (int i = 0; i < add.length; ++i) {
+                                               actualOffs.add(add[i]);
+                                       }
+
+                                       partitionCount += 
namedState.getValue().length;
+                               }
+                       }
+
+                       minCount = Math.min(minCount, partitionCount);
+                       maxCount = Math.max(maxCount, partitionCount);
+                       actualTotalPartitions += partitionCount;
+               }
+
+               for (Map<String, List<Long>> v : actual.values()) {
+                       for (List<Long> l : v.values()) {
+                               Collections.sort(l);
+                       }
+               }
+
+               int maxLoadDiff = maxCount - minCount;
+               Assert.assertTrue("Difference in partition load is > 1 : " + 
maxLoadDiff, maxLoadDiff <= 1);
+               Assert.assertEquals(expectedTotalPartitions, 
actualTotalPartitions);
+               Assert.assertEquals(expected, actual);
+       }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index a4896aa..bb78b6a 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -29,14 +29,18 @@ import 
org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.SerializableObject;
-
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
 import org.junit.Test;
 import org.mockito.Mockito;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -112,9 +116,11 @@ public class CheckpointStateRestoreTest {
                        PendingCheckpoint pending = 
coord.getPendingCheckpoints().values().iterator().next();
                        final long checkpointId = pending.getCheckpointId();
 
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, 
serializedState, serializedKeyGroupStates));
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, 
serializedState, serializedKeyGroupStates));
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, 
serializedState, serializedKeyGroupStates));
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, 
checkpointStateHandles));
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, 
checkpointStateHandles));
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, 
checkpointStateHandles));
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 
@@ -125,11 +131,27 @@ public class CheckpointStateRestoreTest {
                        coord.restoreLatestCheckpointedState(map, true, false);
 
                        // verify that each stateful vertex got the state
-                       verify(statefulExec1, 
times(1)).setInitialState(Mockito.eq(serializedState), 
Mockito.<List<KeyGroupsStateHandle>>any());
-                       verify(statefulExec2, 
times(1)).setInitialState(Mockito.eq(serializedState), 
Mockito.<List<KeyGroupsStateHandle>>any());
-                       verify(statefulExec3, 
times(1)).setInitialState(Mockito.eq(serializedState), 
Mockito.<List<KeyGroupsStateHandle>>any());
-                       verify(statelessExec1, 
times(0)).setInitialState(Mockito.<ChainedStateHandle<StreamStateHandle>>any(), 
Mockito.<List<KeyGroupsStateHandle>>any());
-                       verify(statelessExec2, 
times(0)).setInitialState(Mockito.<ChainedStateHandle<StreamStateHandle>>any(), 
Mockito.<List<KeyGroupsStateHandle>>any());
+
+                       BaseMatcher<CheckpointStateHandles> matcher = new 
BaseMatcher<CheckpointStateHandles>() {
+                               @Override
+                               public boolean matches(Object o) {
+                                       if (o instanceof 
CheckpointStateHandles) {
+                                               return 
((CheckpointStateHandles) 
o).getNonPartitionedStateHandles().equals(serializedState);
+                                       }
+                                       return false;
+                               }
+
+                               @Override
+                               public void describeTo(Description description) 
{
+                                       
description.appendValue(serializedState);
+                               }
+                       };
+
+                       verify(statefulExec1, 
times(1)).setInitialState(Mockito.argThat(matcher), 
Mockito.<List<Collection<OperatorStateHandle>>>any());
+                       verify(statefulExec2, 
times(1)).setInitialState(Mockito.argThat(matcher), 
Mockito.<List<Collection<OperatorStateHandle>>>any());
+                       verify(statefulExec3, 
times(1)).setInitialState(Mockito.argThat(matcher), 
Mockito.<List<Collection<OperatorStateHandle>>>any());
+                       verify(statelessExec1, 
times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), 
Mockito.<List<Collection<OperatorStateHandle>>>any());
+                       verify(statelessExec2, 
times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), 
Mockito.<List<Collection<OperatorStateHandle>>>any());
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -193,9 +215,11 @@ public class CheckpointStateRestoreTest {
                        final long checkpointId = pending.getCheckpointId();
 
                        // the difference to the test "testSetState" is that 
one stateful subtask does not report state
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, 
serializedState, serializedKeyGroupStates));
+                       CheckpointStateHandles checkpointStateHandles = new 
CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, 
checkpointStateHandles));
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId));
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, 
serializedState, serializedKeyGroupStates));
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, 
checkpointStateHandles));
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index 6182ffd..289f5c3 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -197,7 +197,7 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
                JobVertexID jvid = new JobVertexID();
 
                Map<JobVertexID, TaskState> taskGroupStates = new HashMap<>();
-               TaskState taskState = new TaskState(jvid, numberOfStates, 
numberOfStates);
+               TaskState taskState = new TaskState(jvid, numberOfStates, 
numberOfStates, 1);
                taskGroupStates.put(jvid, taskState);
 
                for (int i = 0; i < numberOfStates; i++) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index fd4e02d..b8126e9 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -106,7 +106,7 @@ public class PendingCheckpointTest {
                PendingCheckpoint pending = createPendingCheckpoint();
                PendingCheckpointTest.setTaskState(pending, state);
 
-               pending.acknowledgeTask(ATTEMPT_ID, null, null);
+               pending.acknowledgeTask(ATTEMPT_ID, null);
 
                CompletedCheckpoint checkpoint = pending.finalizeCheckpoint();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
index 7258545..3701359 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
@@ -117,7 +117,7 @@ public class PendingSavepointTest {
 
                Future<String> future = pending.getCompletionFuture();
 
-               pending.acknowledgeTask(ATTEMPT_ID, null, null);
+               pending.acknowledgeTask(ATTEMPT_ID, null);
 
                CompletedCheckpoint checkpoint = pending.finalizeCheckpoint();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
index 6a8d072..9fbe574 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
@@ -186,10 +186,5 @@ public class ZooKeeperCompletedCheckpointStoreITCase 
extends CompletedCheckpoint
                public long getStateSize() throws IOException {
                        return 0;
                }
-
-               @Override
-               public void close() throws IOException {
-                       
-               }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index ef10032..c82be18 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.junit.Test;
@@ -32,7 +33,9 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
 
 import static org.junit.Assert.assertEquals;
@@ -67,17 +70,30 @@ public class SavepointV1Test {
                List<TaskState> taskStates = new ArrayList<>(numTaskStates);
 
                for (int i = 0; i < numTaskStates; i++) {
-                       TaskState taskState = new TaskState(new JobVertexID(), 
numSubtaskStates, numSubtaskStates);
+                       TaskState taskState = new TaskState(new JobVertexID(), 
numSubtaskStates, numSubtaskStates, 1);
                        for (int j = 0; j < numSubtaskStates; j++) {
                                StreamStateHandle stateHandle = new 
ByteStreamStateHandle("Hello".getBytes());
                                taskState.putState(i, new SubtaskState(
                                                new 
ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0));
+
+                               stateHandle = new 
ByteStreamStateHandle("Beautiful".getBytes());
+                               Map<String, long[]> offsetsMap = new 
HashMap<>();
+                               offsetsMap.put("A", new long[]{0, 10, 20});
+                               offsetsMap.put("B", new long[]{30, 40, 50});
+
+                               OperatorStateHandle operatorStateHandle =
+                                               new 
OperatorStateHandle(stateHandle, offsetsMap);
+
+                               taskState.putPartitionableState(
+                                               i,
+                                               new 
ChainedStateHandle<OperatorStateHandle>(
+                                                               
Collections.singletonList(operatorStateHandle)));
                        }
 
                        taskState.putKeyedState(
                                        0,
                                        new KeyGroupsStateHandle(
-                                                       new 
KeyGroupRangeOffsets(1,1, new long[] {42}), new 
ByteStreamStateHandle("Hello".getBytes())));
+                                                       new 
KeyGroupRangeOffsets(1,1, new long[] {42}), new 
ByteStreamStateHandle("World".getBytes())));
 
                        taskStates.add(taskState);
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
index 504143b..1e95732 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
@@ -319,7 +319,7 @@ public class SimpleCheckpointStatsTrackerTest {
                                JobVertexID operatorId = 
operatorIds[operatorIndex];
                                int parallelism = 
operatorParallelism[operatorIndex];
 
-                               TaskState taskState = new TaskState(operatorId, 
parallelism, maxParallelism);
+                               TaskState taskState = new TaskState(operatorId, 
parallelism, maxParallelism, 1);
 
                                taskGroupStates.put(operatorId, taskState);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index ef8e3bd..9b12cac 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -26,6 +26,7 @@ import akka.testkit.JavaTestKit;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.runtime.akka.AkkaUtils;
 import org.apache.flink.runtime.akka.ListeningBehaviour;
 import org.apache.flink.runtime.blob.BlobServer;
@@ -54,9 +55,11 @@ import 
org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService;
 import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.RetrievableStreamStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingJobManager;
@@ -80,6 +83,7 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -441,10 +445,15 @@ public class JobManagerHARecoveryTest {
                private int completedCheckpoints = 0;
 
                @Override
-               public void 
setInitialState(ChainedStateHandle<StreamStateHandle> chainedState, 
List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+               public void setInitialState(
+                       ChainedStateHandle<StreamStateHandle> chainedState,
+                       List<KeyGroupsStateHandle> keyGroupsState,
+                       List<Collection<OperatorStateHandle>> 
partitionableOperatorState) throws Exception {
                        int subtaskIndex = getIndexInSubtaskGroup();
                        if (subtaskIndex < recoveredStates.length) {
-                               recoveredStates[subtaskIndex] = 
InstantiationUtil.deserializeObject(chainedState.get(0).openInputStream());
+                               try (FSDataInputStream in = 
chainedState.get(0).openInputStream()) {
+                                       recoveredStates[subtaskIndex] = 
InstantiationUtil.deserializeObject(in);
+                               }
                        }
                }
 
@@ -456,11 +465,12 @@ public class JobManagerHARecoveryTest {
 
                                RetrievableStreamStateHandle<Long> state = new 
RetrievableStreamStateHandle<Long>(byteStreamStateHandle);
                                ChainedStateHandle<StreamStateHandle> 
chainedStateHandle = new 
ChainedStateHandle<StreamStateHandle>(Collections.singletonList(state));
+                               CheckpointStateHandles checkpointStateHandles =
+                                               new 
CheckpointStateHandles(chainedStateHandle, null, 
Collections.<KeyGroupsStateHandle>emptyList());
 
                                getEnvironment().acknowledgeCheckpoint(
                                                checkpointId,
-                                               chainedStateHandle,
-                                               
Collections.<KeyGroupsStateHandle>emptyList(),
+                                               checkpointStateHandles,
                                                0L, 0L, 0L, 0L);
                                return true;
                        } catch (Exception ex) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index 6a6ac64..4873335 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -23,11 +23,12 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.junit.Test;
 
@@ -65,13 +66,17 @@ public class CheckpointMessagesTest {
 
                        KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
 
+                       CheckpointStateHandles checkpointStateHandles =
+                                       new CheckpointStateHandles(
+                                                       
CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
+                                                       
CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new 
JobVertexID(), 0, 2, 8),
+                                                       
CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, 
Collections.singletonList(new MyHandle())));
+
                        AcknowledgeCheckpoint withState = new 
AcknowledgeCheckpoint(
                                        new JobID(),
                                        new ExecutionAttemptID(),
                                        87658976143L,
-                                       
CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
-                                       
CheckpointCoordinatorTest.generateKeyGroupState(
-                                                       keyGroupRange, 
Collections.singletonList(new MyHandle())));
+                                       checkpointStateHandles);
 
                        testSerializabilityEqualsHashCode(noState);
                        testSerializabilityEqualsHashCode(withState);
@@ -83,7 +88,6 @@ public class CheckpointMessagesTest {
 
        private static void testSerializabilityEqualsHashCode(Serializable o) 
throws IOException {
                Object copy = CommonTestUtils.createCopySerializable(o);
-               System.out.println(o.getClass() +" "+copy.getClass());
                assertEquals(o, copy);
                assertEquals(o.hashCode(), copy.hashCode());
                assertNotNull(o.toString());
@@ -117,9 +121,6 @@ public class CheckpointMessagesTest {
                }
 
                @Override
-               public void close() throws IOException {}
-
-               @Override
                public FSDataInputStream openInputStream() throws IOException {
                        return null;
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index a857d1b..c855230 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
@@ -162,7 +163,7 @@ public class DummyEnvironment implements Environment {
        @Override
        public void acknowledgeCheckpoint(
                        long checkpointId,
-                       ChainedStateHandle<StreamStateHandle> 
chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles,
+                       CheckpointStateHandles checkpointStateHandles,
                        long synchronousDurationMillis, long 
asynchronousDurationMillis,
                        long bytesBufferedInAlignment, long 
alignmentDurationNanos) {
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index 75e88eb..c3ed6c0 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -46,6 +46,7 @@ import 
org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
@@ -323,7 +324,7 @@ public class MockEnvironment implements Environment {
        @Override
        public void acknowledgeCheckpoint(
                        long checkpointId,
-                       ChainedStateHandle<StreamStateHandle> 
chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles,
+                       CheckpointStateHandles checkpointStateHandles,
                        long synchronousDurationMillis, long 
asynchronousDurationMillis,
                        long bytesBufferedInAlignment, long 
alignmentDurationNanos) {
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
index 1039568..4279635 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
@@ -32,8 +32,8 @@ import org.apache.flink.runtime.query.netty.KvStateClient;
 import org.apache.flink.runtime.query.netty.KvStateServer;
 import org.apache.flink.runtime.query.netty.UnknownKvStateID;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.heap.HeapValueState;
@@ -246,7 +246,7 @@ public class QueryableStateClientTest {
                MemoryStateBackend backend = new MemoryStateBackend();
                DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 
-               KeyedStateBackend<Integer> keyedStateBackend = 
backend.createKeyedStateBackend(dummyEnv,
+               AbstractKeyedStateBackend<Integer> keyedStateBackend = 
backend.createKeyedStateBackend(dummyEnv,
                                new JobID(),
                                "test_op",
                                IntSerializer.INSTANCE,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
index c8fb4bb..0db8b31 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
@@ -41,9 +41,9 @@ import org.apache.flink.runtime.query.KvStateServerAddress;
 import org.apache.flink.runtime.query.netty.message.KvStateRequest;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.KvState;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -538,7 +538,8 @@ public class KvStateClientTest {
                KvStateRegistry dummyRegistry = new KvStateRegistry();
                DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
                dummyEnv.setKvStateRegistry(dummyRegistry);
-               KeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
+
+               AbstractKeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
                                dummyEnv,
                                new JobID(),
                                "test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
index 7e6d713..ed4a822 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
@@ -38,6 +38,7 @@ import 
org.apache.flink.runtime.query.netty.message.KvStateRequestFailure;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateBackend;
@@ -92,7 +93,7 @@ public class KvStateServerHandlerTest {
                AbstractStateBackend abstractBackend = new MemoryStateBackend();
                DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
                dummyEnv.setKvStateRegistry(registry);
-               KeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
+               AbstractKeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
                                dummyEnv,
                                new JobID(),
                                "test_op",
@@ -490,7 +491,7 @@ public class KvStateServerHandlerTest {
                AbstractStateBackend abstractBackend = new MemoryStateBackend();
                DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
                dummyEnv.setKvStateRegistry(registry);
-               KeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
+               AbstractKeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
                                dummyEnv,
                                new JobID(),
                                "test_op",
@@ -586,7 +587,7 @@ public class KvStateServerHandlerTest {
                AbstractStateBackend abstractBackend = new MemoryStateBackend();
                DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
                dummyEnv.setKvStateRegistry(registry);
-               KeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
+               AbstractKeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
                                dummyEnv,
                                new JobID(),
                                "test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
index e92fb10..b1c4a9f 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
@@ -41,9 +41,9 @@ import org.apache.flink.runtime.query.KvStateServerAddress;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
@@ -91,7 +91,7 @@ public class KvStateServerTest {
                        AbstractStateBackend abstractBackend = new 
MemoryStateBackend();
                        DummyEnvironment dummyEnv = new 
DummyEnvironment("test", 1, 0);
                        dummyEnv.setKvStateRegistry(registry);
-                       KeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
+                       AbstractKeyedStateBackend<Integer> backend = 
abstractBackend.createKeyedStateBackend(
                                        dummyEnv,
                                        new JobID(),
                                        "test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
deleted file mode 100644
index e613105..0000000
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import org.junit.Test;
-
-import java.io.Closeable;
-import java.io.IOException;
-
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
-
-public class AbstractCloseableHandleTest {
-
-       @Test
-       public void testRegisterThenClose() throws Exception {
-               Closeable closeable = mock(Closeable.class);
-
-               AbstractCloseableHandle handle = new CloseableHandle();
-               assertFalse(handle.isClosed());
-
-               // no immediate closing
-               handle.registerCloseable(closeable);
-               verify(closeable, times(0)).close();
-               assertFalse(handle.isClosed());
-
-               // close forwarded once
-               handle.close();
-               verify(closeable, times(1)).close();
-               assertTrue(handle.isClosed());
-
-               // no repeated closing
-               handle.close();
-               verify(closeable, times(1)).close();
-               assertTrue(handle.isClosed());
-       }
-
-       @Test
-       public void testCloseThenRegister() throws Exception {
-               Closeable closeable = mock(Closeable.class);
-
-               AbstractCloseableHandle handle = new CloseableHandle();
-               assertFalse(handle.isClosed());
-
-               // close the handle before setting the closeable
-               handle.close();
-               assertTrue(handle.isClosed());
-
-               // immediate closing
-               try {
-                       handle.registerCloseable(closeable);
-                       fail("this should throw an excepion");
-               } catch (IOException e) {
-                       // expected
-                       assertTrue(e.getMessage().contains("closed"));
-               }
-
-               // should still have called "close" on the Closeable
-               verify(closeable, times(1)).close();
-               assertTrue(handle.isClosed());
-
-               // no repeated closing
-               handle.close();
-               verify(closeable, times(1)).close();
-               assertTrue(handle.isClosed());
-       }
-
-       // 
------------------------------------------------------------------------
-
-       private static final class CloseableHandle extends 
AbstractCloseableHandle {
-               private static final long serialVersionUID = 1L;
-
-               @Override
-               public void discardState() {}
-
-               @Override
-               public long getStateSize() {
-                       return 0;
-               }
-       }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
index bc0b9c3..0b04ebc 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
@@ -20,16 +20,11 @@ package org.apache.flink.runtime.state;
 
 import org.apache.commons.io.FileUtils;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.testutils.CommonTestUtils;
-import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
-
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
-
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -39,9 +34,12 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.net.URI;
 import java.util.Random;
-import java.util.UUID;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> 
{
 
@@ -188,18 +186,21 @@ public class FileStateBackendTest extends 
StateBackendTestBase<FsStateBackend> {
        }
 
        private static void validateBytesInStream(InputStream is, byte[] data) 
throws IOException {
-               byte[] holder = new byte[data.length];
+               try {
+                       byte[] holder = new byte[data.length];
 
-               int pos = 0;
-               int read;
-               while (pos < holder.length && (read = is.read(holder, pos, 
holder.length - pos)) != -1) {
-                       pos += read;
-               }
+                       int pos = 0;
+                       int read;
+                       while (pos < holder.length && (read = is.read(holder, 
pos, holder.length - pos)) != -1) {
+                               pos += read;
+                       }
 
-               assertEquals("not enough data", holder.length, pos);
-               assertEquals("too much data", -1, is.read());
-               assertArrayEquals("wrong data", data, holder);
-               is.close();
+                       assertEquals("not enough data", holder.length, pos);
+                       assertEquals("too much data", -1, is.read());
+                       assertArrayEquals("wrong data", data, holder);
+               } finally {
+                       is.close();
+               }
        }
 
        @Test

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index 944938b..ac6adff 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -19,8 +19,6 @@
 package org.apache.flink.runtime.state;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.junit.Test;
 
@@ -29,7 +27,10 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.util.HashMap;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /**
  * Tests for the {@link 
org.apache.flink.runtime.state.memory.MemoryStateBackend}.
@@ -105,10 +106,10 @@ public class MemoryStateBackendTest extends 
StateBackendTestBase<MemoryStateBack
 
                        assertNotNull(handle);
 
-                       ObjectInputStream ois = new 
ObjectInputStream(handle.openInputStream());
-                       assertEquals(state, ois.readObject());
-                       assertTrue(ois.available() <= 0);
-                       ois.close();
+                       try (ObjectInputStream ois = new 
ObjectInputStream(handle.openInputStream())) {
+                               assertEquals(state, ois.readObject());
+                               assertTrue(ois.available() <= 0);
+                       }
                }
                catch (Exception e) {
                        e.printStackTrace();

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
new file mode 100644
index 0000000..56c8987
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.Iterator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for the {@link OperatorStateBackend} obtained from a {@link MemoryStateBackend}: creation,
+ * registration of partitionable list states, and snapshot/restore of their contents.
+ */
+public class OperatorStateBackendTest {
+
+	private final AbstractStateBackend abstractStateBackend = new MemoryStateBackend(1024);
+
+	/** Creates a new operator state backend from {@link #abstractStateBackend}. */
+	private OperatorStateBackend createNewOperatorStateBackend() throws Exception {
+		return abstractStateBackend.createOperatorStateBackend(null, "test-operator");
+	}
+
+	/** A newly created backend has no registered states. */
+	@Test
+	public void testCreateNew() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		assertNotNull(operatorStateBackend);
+		assertTrue(operatorStateBackend.getRegisteredStateNames().isEmpty());
+	}
+
+	/** States are registered per descriptor name; requesting a registered state again accesses the same list. */
+	@Test
+	public void testRegisterStates() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
+		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+		ListState<Serializable> listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		assertNotNull(listState1);
+		assertEquals(1, operatorStateBackend.getRegisteredStateNames().size());
+		Iterator<Serializable> it = listState1.get().iterator();
+		assertFalse(it.hasNext());
+		listState1.add(42);
+		listState1.add(4711);
+
+		it = listState1.get().iterator();
+		assertEquals(42, it.next());
+		assertEquals(4711, it.next());
+		assertFalse(it.hasNext());
+
+		ListState<Serializable> listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2);
+		assertNotNull(listState2);
+		assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
+		// check the new state's own iterator; 'it' still points at the exhausted iterator of listState1
+		it = listState2.get().iterator();
+		assertFalse(it.hasNext());
+		listState2.add(7);
+		listState2.add(13);
+		listState2.add(23);
+
+		it = listState2.get().iterator();
+		assertEquals(7, it.next());
+		assertEquals(13, it.next());
+		assertEquals(23, it.next());
+		assertFalse(it.hasNext());
+
+		// requesting "test1" again must give access to the same list that listState1 writes to
+		ListState<Serializable> listState1b = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		assertNotNull(listState1b);
+		listState1b.add(123);
+		it = listState1b.get().iterator();
+		assertEquals(42, it.next());
+		assertEquals(4711, it.next());
+		assertEquals(123, it.next());
+		assertFalse(it.hasNext());
+
+		it = listState1.get().iterator();
+		assertEquals(42, it.next());
+		assertEquals(4711, it.next());
+		assertEquals(123, it.next());
+		assertFalse(it.hasNext());
+
+		it = listState1b.get().iterator();
+		assertEquals(42, it.next());
+		assertEquals(4711, it.next());
+		assertEquals(123, it.next());
+		assertFalse(it.hasNext());
+	}
+
+	/** List state contents written before a snapshot are visible again after restoring from the snapshot handle. */
+	@Test
+	public void testSnapshotRestore() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
+		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+		ListState<Serializable> listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		ListState<Serializable> listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2);
+
+		listState1.add(42);
+		listState1.add(4711);
+
+		listState2.add(7);
+		listState2.add(13);
+		listState2.add(23);
+
+		CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator");
+		OperatorStateHandle stateHandle = operatorStateBackend.snapshot(1, 1, streamFactory).get();
+
+		OperatorStateBackend restoredBackend = null;
+		try {
+			// the original backend is no longer needed once its contents are captured in stateHandle
+			operatorStateBackend.dispose();
+
+			restoredBackend = abstractStateBackend.restoreOperatorStateBackend(
+					null, "testOperator", Collections.singletonList(stateHandle));
+
+			// restored states are expected to be registered only once they are requested by name
+			assertEquals(0, restoredBackend.getRegisteredStateNames().size());
+
+			listState1 = restoredBackend.getPartitionableState(stateDescriptor1);
+			listState2 = restoredBackend.getPartitionableState(stateDescriptor2);
+
+			assertEquals(2, restoredBackend.getRegisteredStateNames().size());
+
+			Iterator<Serializable> it = listState1.get().iterator();
+			assertEquals(42, it.next());
+			assertEquals(4711, it.next());
+			assertFalse(it.hasNext());
+
+			it = listState2.get().iterator();
+			assertEquals(7, it.next());
+			assertEquals(13, it.next());
+			assertEquals(23, it.next());
+			assertFalse(it.hasNext());
+		} finally {
+			try {
+				// release the restored backend even if one of the assertions above failed
+				if (restoredBackend != null) {
+					restoredBackend.dispose();
+				}
+			} finally {
+				stateHandle.discardState();
+			}
+		}
+	}
+}

Reply via email to