[FLINK-6034] [checkpoints] Introduce KeyedStateHandle abstraction for the snapshots in keyed streams


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/cd552741
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/cd552741
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/cd552741

Branch: refs/heads/table-retraction
Commit: cd5527417a1cae57073a8855c6c3b88c88c780aa
Parents: 89866a5
Author: xiaogang.sxg <[email protected]>
Authored: Thu Mar 23 23:32:15 2017 +0800
Committer: Stefan Richter <[email protected]>
Committed: Tue Mar 28 20:05:28 2017 +0200

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         | 46 ++++++++++----
 .../state/RocksDBAsyncSnapshotTest.java         |  3 +-
 .../state/RocksDBStateBackendTest.java          | 21 ++++---
 .../cep/operator/CEPMigration12to13Test.java    | 14 ++---
 .../apache/flink/migration/MigrationUtil.java   | 10 +--
 .../checkpoint/StateAssignmentOperation.java    | 41 ++++++------
 .../flink/runtime/checkpoint/SubtaskState.java  | 14 ++---
 .../savepoint/SavepointV1Serializer.java        | 42 +++++++------
 .../state/AbstractKeyedStateBackend.java        |  2 +-
 .../runtime/state/KeyGroupsStateHandle.java     | 39 ++++--------
 .../flink/runtime/state/KeyedStateHandle.java   | 40 ++++++++++++
 .../state/StateInitializationContextImpl.java   | 28 ++++++++-
 .../StateSnapshotContextSynchronousImpl.java    | 12 ++--
 .../flink/runtime/state/TaskStateHandles.java   | 16 ++---
 .../state/heap/HeapKeyedStateBackend.java       | 46 ++++++++++----
 .../checkpoint/CheckpointCoordinatorTest.java   | 29 +++++----
 .../checkpoint/CheckpointStateRestoreTest.java  |  3 +-
 .../savepoint/MigrationV0ToV1Test.java          | 14 ++++-
 .../KeyedStateCheckpointOutputStreamTest.java   |  4 +-
 .../runtime/state/StateBackendTestBase.java     | 66 ++++++++++----------
 ...pKeyedStateBackendSnapshotMigrationTest.java |  3 +-
 .../api/operators/AbstractStreamOperator.java   |  7 ++-
 .../api/operators/OperatorSnapshotResult.java   | 18 +++---
 .../runtime/tasks/OperatorStateHandles.java     | 14 ++---
 .../streaming/runtime/tasks/StreamTask.java     | 14 ++---
 .../operators/AbstractStreamOperatorTest.java   | 10 +--
 .../operators/OperatorSnapshotResultTest.java   | 10 +--
 .../StateInitializationContextImplTest.java     |  9 +--
 .../tasks/InterruptSensitiveRestoreTest.java    | 17 ++---
 .../streaming/runtime/tasks/StreamTaskTest.java | 14 ++---
 .../util/AbstractStreamOperatorTestHarness.java | 25 ++++----
 .../KeyedOneInputStreamOperatorTestHarness.java | 17 ++---
 .../KeyedTwoInputStreamOperatorTestHarness.java |  3 +-
 33 files changed, 389 insertions(+), 262 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
 
b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 2ce527f..0407070 100644
--- 
a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ 
b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -40,6 +40,7 @@ import 
org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.migration.MigrationNamespaceSerializerProxy;
 import org.apache.flink.migration.MigrationUtil;
 import org.apache.flink.migration.contrib.streaming.state.RocksDBStateBackend;
+import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable;
 import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
@@ -52,6 +53,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
@@ -257,7 +259,7 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
         * @throws Exception
         */
        @Override
-       public RunnableFuture<KeyGroupsStateHandle> snapshot(
+       public RunnableFuture<KeyedStateHandle> snapshot(
                        final long checkpointId,
                        final long timestamp,
                        final CheckpointStreamFactory streamFactory,
@@ -286,8 +288,8 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                }
 
                // implementation of the async IO operation, based on FutureTask
-               AbstractAsyncIOCallable<KeyGroupsStateHandle, 
CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
-                               new 
AbstractAsyncIOCallable<KeyGroupsStateHandle, 
CheckpointStreamFactory.CheckpointStateOutputStream>() {
+               AbstractAsyncIOCallable<KeyedStateHandle, 
CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
+                               new AbstractAsyncIOCallable<KeyedStateHandle, 
CheckpointStreamFactory.CheckpointStateOutputStream>() {
 
                                        @Override
                                        public 
CheckpointStreamFactory.CheckpointStateOutputStream openIOHandle() throws 
Exception {
@@ -620,7 +622,7 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        }
 
        @Override
-       public void restore(Collection<KeyGroupsStateHandle> restoreState) 
throws Exception {
+       public void restore(Collection<KeyedStateHandle> restoreState) throws 
Exception {
                LOG.info("Initializing RocksDB keyed state backend from 
snapshot.");
 
                if (LOG.isDebugEnabled()) {
@@ -669,17 +671,23 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                /**
                 * Restores all key-groups data that is referenced by the 
passed state handles.
                 *
-                * @param keyGroupsStateHandles List of all key groups state 
handles that shall be restored.
+                * @param keyedStateHandles List of all key groups state 
handles that shall be restored.
                 * @throws IOException
                 * @throws ClassNotFoundException
                 * @throws RocksDBException
                 */
-               public void doRestore(Collection<KeyGroupsStateHandle> 
keyGroupsStateHandles)
+               public void doRestore(Collection<KeyedStateHandle> 
keyedStateHandles)
                                throws IOException, ClassNotFoundException, 
RocksDBException {
 
-                       for (KeyGroupsStateHandle keyGroupsStateHandle : 
keyGroupsStateHandles) {
-                               if (keyGroupsStateHandle != null) {
-                                       this.currentKeyGroupsStateHandle = 
keyGroupsStateHandle;
+                       for (KeyedStateHandle keyedStateHandle : 
keyedStateHandles) {
+                               if (keyedStateHandle != null) {
+
+                                       if (!(keyedStateHandle instanceof 
KeyGroupsStateHandle)) {
+                                               throw new 
IllegalStateException("Unexpected state handle type, " +
+                                                               "expected: " + 
KeyGroupsStateHandle.class +
+                                                               ", but found: " 
+ keyedStateHandle.getClass());
+                                       }
+                                       this.currentKeyGroupsStateHandle = 
(KeyGroupsStateHandle) keyedStateHandle;
                                        restoreKeyGroupsInStateHandle();
                                }
                        }
@@ -761,6 +769,12 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                private void restoreKVStateData() throws IOException, 
RocksDBException {
                        //for all key-groups in the current state handle...
                        for (Tuple2<Integer, Long> keyGroupOffset : 
currentKeyGroupsStateHandle.getGroupRangeOffsets()) {
+                               int keyGroup = keyGroupOffset.f0;
+
+                               // Check that restored key groups all belong to 
the backend
+                               
Preconditions.checkState(rocksDBKeyedStateBackend.getKeyGroupRange().contains(keyGroup),
+                                       "The key group must belong to the 
backend");
+
                                long offset = keyGroupOffset.f1;
                                //not empty key-group?
                                if (0L != offset) {
@@ -1143,15 +1157,25 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
         * For backwards compatibility, remove again later!
         */
        @Deprecated
-       private void 
restoreOldSavepointKeyedState(Collection<KeyGroupsStateHandle> restoreState) 
throws Exception {
+       private void restoreOldSavepointKeyedState(Collection<KeyedStateHandle> 
restoreState) throws Exception {
 
                if (restoreState.isEmpty()) {
                        return;
                }
 
                Preconditions.checkState(1 == restoreState.size(), "Only one 
element expected here.");
+
+               KeyedStateHandle keyedStateHandle = 
restoreState.iterator().next();
+               if (!(keyedStateHandle instanceof 
MigrationKeyGroupStateHandle)) {
+                       throw new IllegalStateException("Unexpected state 
handle type, " +
+                                       "expected: " + 
MigrationKeyGroupStateHandle.class +
+                                       ", but found: " + 
keyedStateHandle.getClass());
+               }
+
+               MigrationKeyGroupStateHandle keyGroupStateHandle = 
(MigrationKeyGroupStateHandle) keyedStateHandle;
+
                HashMap<String, RocksDBStateBackend.FinalFullyAsyncSnapshot> 
namedStates;
-               try (FSDataInputStream inputStream = 
restoreState.iterator().next().openInputStream()) {
+               try (FSDataInputStream inputStream = 
keyGroupStateHandle.openInputStream()) {
                        namedStates = 
InstantiationUtil.deserializeObject(inputStream, userCodeClassLoader);
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
----------------------------------------------------------------------
diff --git 
a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
 
b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index 90de7a6..ffe2ce2 100644
--- 
a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ 
b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -42,6 +42,7 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
@@ -343,7 +344,7 @@ public class RocksDBAsyncSnapshotTest {
                        StringSerializer.INSTANCE,
                        new ValueStateDescriptor<>("foobar", String.class));
 
-               RunnableFuture<KeyGroupsStateHandle> snapshotFuture = 
keyedStateBackend.snapshot(
+               RunnableFuture<KeyedStateHandle> snapshotFuture = 
keyedStateBackend.snapshot(
                        checkpointId, timestamp, checkpointStreamFactory, 
CheckpointOptions.forFullCheckpoint());
 
                try {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
 
b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index 708613b..d95a9b4 100644
--- 
a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ 
b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.StateBackendTestBase;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -172,7 +173,7 @@ public class RocksDBStateBackendTest extends 
StateBackendTestBase<RocksDBStateBa
        @Test
        public void testRunningSnapshotAfterBackendClosed() throws Exception {
                setupRocksKeyedStateBackend();
-               RunnableFuture<KeyGroupsStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
+               RunnableFuture<KeyedStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
                        CheckpointOptions.forFullCheckpoint());
 
                RocksDB spyDB = keyedStateBackend.db;
@@ -210,7 +211,7 @@ public class RocksDBStateBackendTest extends 
StateBackendTestBase<RocksDBStateBa
        @Test
        public void testReleasingSnapshotAfterBackendClosed() throws Exception {
                setupRocksKeyedStateBackend();
-               RunnableFuture<KeyGroupsStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
+               RunnableFuture<KeyedStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
                        CheckpointOptions.forFullCheckpoint());
 
                RocksDB spyDB = keyedStateBackend.db;
@@ -239,7 +240,7 @@ public class RocksDBStateBackendTest extends 
StateBackendTestBase<RocksDBStateBa
        @Test
        public void testDismissingSnapshot() throws Exception {
                setupRocksKeyedStateBackend();
-               RunnableFuture<KeyGroupsStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory, 
CheckpointOptions.forFullCheckpoint());
+               RunnableFuture<KeyedStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory, 
CheckpointOptions.forFullCheckpoint());
                snapshot.cancel(true);
                verifyRocksObjectsReleased();
        }
@@ -247,7 +248,7 @@ public class RocksDBStateBackendTest extends 
StateBackendTestBase<RocksDBStateBa
        @Test
        public void testDismissingSnapshotNotRunnable() throws Exception {
                setupRocksKeyedStateBackend();
-               RunnableFuture<KeyGroupsStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory, 
CheckpointOptions.forFullCheckpoint());
+               RunnableFuture<KeyedStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory, 
CheckpointOptions.forFullCheckpoint());
                snapshot.cancel(true);
                Thread asyncSnapshotThread = new Thread(snapshot);
                asyncSnapshotThread.start();
@@ -264,7 +265,7 @@ public class RocksDBStateBackendTest extends 
StateBackendTestBase<RocksDBStateBa
        @Test
        public void testCompletingSnapshot() throws Exception {
                setupRocksKeyedStateBackend();
-               RunnableFuture<KeyGroupsStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory, 
CheckpointOptions.forFullCheckpoint());
+               RunnableFuture<KeyedStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory, 
CheckpointOptions.forFullCheckpoint());
                Thread asyncSnapshotThread = new Thread(snapshot);
                asyncSnapshotThread.start();
                waiter.await(); // wait for snapshot to run
@@ -272,10 +273,10 @@ public class RocksDBStateBackendTest extends 
StateBackendTestBase<RocksDBStateBa
                runStateUpdates();
                blocker.trigger(); // allow checkpointing to start writing
                waiter.await(); // wait for snapshot stream writing to run
-               KeyGroupsStateHandle keyGroupsStateHandle = snapshot.get();
-               assertNotNull(keyGroupsStateHandle);
-               assertTrue(keyGroupsStateHandle.getStateSize() > 0);
-               assertEquals(2, keyGroupsStateHandle.getNumberOfKeyGroups());
+               KeyedStateHandle keyedStateHandle = snapshot.get();
+               assertNotNull(keyedStateHandle);
+               assertTrue(keyedStateHandle.getStateSize() > 0);
+               assertEquals(2, 
keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
                assertTrue(testStreamFactory.getLastCreatedStream().isClosed());
                asyncSnapshotThread.join();
                verifyRocksObjectsReleased();
@@ -284,7 +285,7 @@ public class RocksDBStateBackendTest extends 
StateBackendTestBase<RocksDBStateBa
        @Test
        public void testCancelRunningSnapshot() throws Exception {
                setupRocksKeyedStateBackend();
-               RunnableFuture<KeyGroupsStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory, 
CheckpointOptions.forFullCheckpoint());
+               RunnableFuture<KeyedStateHandle> snapshot = 
keyedStateBackend.snapshot(0L, 0L, testStreamFactory, 
CheckpointOptions.forFullCheckpoint());
                Thread asyncSnapshotThread = new Thread(snapshot);
                asyncSnapshotThread.start();
                waiter.await(); // wait for snapshot to run

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
 
b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
index f230bbc..dbe4230 100644
--- 
a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
+++ 
b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
@@ -26,7 +26,7 @@ import org.apache.flink.cep.nfa.NFA;
 import org.apache.flink.cep.nfa.compiler.NFACompiler;
 import org.apache.flink.cep.pattern.Pattern;
 import org.apache.flink.cep.pattern.conditions.SimpleCondition;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -128,8 +128,8 @@ public class CEPMigration12to13Test {
                final OperatorStateHandles snapshot = new OperatorStateHandles(
                        (int) ois.readObject(),
                        (StreamStateHandle) ois.readObject(),
-                       (Collection<KeyGroupsStateHandle>) ois.readObject(),
-                       (Collection<KeyGroupsStateHandle>) ois.readObject(),
+                       (Collection<KeyedStateHandle>) ois.readObject(),
+                       (Collection<KeyedStateHandle>) ois.readObject(),
                        (Collection<OperatorStateHandle>) ois.readObject(),
                        (Collection<OperatorStateHandle>) ois.readObject()
                );
@@ -243,8 +243,8 @@ public class CEPMigration12to13Test {
                final OperatorStateHandles snapshot = new OperatorStateHandles(
                        (int) ois.readObject(),
                        (StreamStateHandle) ois.readObject(),
-                       (Collection<KeyGroupsStateHandle>) ois.readObject(),
-                       (Collection<KeyGroupsStateHandle>) ois.readObject(),
+                       (Collection<KeyedStateHandle>) ois.readObject(),
+                       (Collection<KeyedStateHandle>) ois.readObject(),
                        (Collection<OperatorStateHandle>) ois.readObject(),
                        (Collection<OperatorStateHandle>) ois.readObject()
                );
@@ -363,8 +363,8 @@ public class CEPMigration12to13Test {
                final OperatorStateHandles snapshot = new OperatorStateHandles(
                        (int) ois.readObject(),
                        (StreamStateHandle) ois.readObject(),
-                       (Collection<KeyGroupsStateHandle>) ois.readObject(),
-                       (Collection<KeyGroupsStateHandle>) ois.readObject(),
+                       (Collection<KeyedStateHandle>) ois.readObject(),
+                       (Collection<KeyedStateHandle>) ois.readObject(),
                        (Collection<OperatorStateHandle>) ois.readObject(),
                        (Collection<OperatorStateHandle>) ois.readObject()
                );

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java 
b/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
index 9427f72..a4e3a2e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
@@ -19,17 +19,17 @@
 package org.apache.flink.migration;
 
 import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 
 import java.util.Collection;
 
 public class MigrationUtil {
 
        @SuppressWarnings("deprecation")
-       public static boolean 
isOldSavepointKeyedState(Collection<KeyGroupsStateHandle> 
keyGroupsStateHandles) {
-               return (keyGroupsStateHandles != null)
-                               && (keyGroupsStateHandles.size() == 1)
-                               && (keyGroupsStateHandles.iterator().next() 
instanceof MigrationKeyGroupStateHandle);
+       public static boolean 
isOldSavepointKeyedState(Collection<KeyedStateHandle> keyedStateHandles) {
+               return (keyedStateHandles != null)
+                               && (keyedStateHandles.size() == 1)
+                               && (keyedStateHandles.iterator().next() 
instanceof MigrationKeyGroupStateHandle);
        }
 
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 3fda430..ac70e1a 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -160,8 +161,8 @@ public class StateAssignmentOperation {
                @SuppressWarnings("unchecked")
                List<OperatorStateHandle>[] parallelOpStatesStream = new 
List[chainLength];
 
-               List<KeyGroupsStateHandle> parallelKeyedStatesBackend = new 
ArrayList<>(oldParallelism);
-               List<KeyGroupsStateHandle> parallelKeyedStateStream = new 
ArrayList<>(oldParallelism);
+               List<KeyedStateHandle> parallelKeyedStatesBackend = new 
ArrayList<>(oldParallelism);
+               List<KeyedStateHandle> parallelKeyedStateStream = new 
ArrayList<>(oldParallelism);
 
                for (int p = 0; p < oldParallelism; ++p) {
                        SubtaskState subtaskState = taskState.getState(p);
@@ -173,12 +174,12 @@ public class StateAssignmentOperation {
                                collectParallelStatesByChainOperator(
                                                parallelOpStatesStream, 
subtaskState.getRawOperatorState());
 
-                               KeyGroupsStateHandle keyedStateBackend = 
subtaskState.getManagedKeyedState();
+                               KeyedStateHandle keyedStateBackend = 
subtaskState.getManagedKeyedState();
                                if (null != keyedStateBackend) {
                                        
parallelKeyedStatesBackend.add(keyedStateBackend);
                                }
 
-                               KeyGroupsStateHandle keyedStateStream = 
subtaskState.getRawKeyedState();
+                               KeyedStateHandle keyedStateStream = 
subtaskState.getRawKeyedState();
                                if (null != keyedStateStream) {
                                        
parallelKeyedStateStream.add(keyedStateStream);
                                }
@@ -252,13 +253,13 @@ public class StateAssignmentOperation {
                                        .getTaskVertices()[subTaskIdx]
                                        .getCurrentExecutionAttempt();
 
-                       List<KeyGroupsStateHandle> newKeyedStatesBackend;
-                       List<KeyGroupsStateHandle> newKeyedStateStream;
+                       List<KeyedStateHandle> newKeyedStatesBackend;
+                       List<KeyedStateHandle> newKeyedStateStream;
                        if (oldParallelism == newParallelism) {
                                SubtaskState subtaskState = 
taskState.getState(subTaskIdx);
                                if (subtaskState != null) {
-                                       KeyGroupsStateHandle 
oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
-                                       KeyGroupsStateHandle 
oldKeyedStatesStream = subtaskState.getRawKeyedState();
+                                       KeyedStateHandle oldKeyedStatesBackend 
= subtaskState.getManagedKeyedState();
+                                       KeyedStateHandle oldKeyedStatesStream = 
subtaskState.getRawKeyedState();
                                        newKeyedStatesBackend = 
oldKeyedStatesBackend != null ? Collections.singletonList(
                                                        oldKeyedStatesBackend) 
: null;
                                        newKeyedStateStream = 
oldKeyedStatesStream != null ? Collections.singletonList(
@@ -269,8 +270,8 @@ public class StateAssignmentOperation {
                                }
                        } else {
                                KeyGroupRange subtaskKeyGroupIds = 
keyGroupPartitions.get(subTaskIdx);
-                               newKeyedStatesBackend = 
getKeyGroupsStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
-                               newKeyedStateStream = 
getKeyGroupsStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
+                               newKeyedStatesBackend = 
getKeyedStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
+                               newKeyedStateStream = 
getKeyedStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
                        }
 
                        TaskStateHandles taskStateHandles = new 
TaskStateHandles(
@@ -290,19 +291,21 @@ public class StateAssignmentOperation {
         * <p>
         * <p>This is publicly visible to be used in tests.
         */
-       public static List<KeyGroupsStateHandle> getKeyGroupsStateHandles(
-                       Collection<KeyGroupsStateHandle> allKeyGroupsHandles,
-                       KeyGroupRange subtaskKeyGroupIds) {
+       public static List<KeyedStateHandle> getKeyedStateHandles(
+                       Collection<? extends KeyedStateHandle> 
keyedStateHandles,
+                       KeyGroupRange subtaskKeyGroupRange) {
 
-               List<KeyGroupsStateHandle> subtaskKeyGroupStates = new 
ArrayList<>();
+               List<KeyedStateHandle> subtaskKeyedStateHandles = new 
ArrayList<>();
 
-               for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) 
{
-                       KeyGroupsStateHandle intersection = 
storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds);
-                       if (intersection.getNumberOfKeyGroups() > 0) {
-                               subtaskKeyGroupStates.add(intersection);
+               for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+                       KeyedStateHandle intersectedKeyedStateHandle = 
keyedStateHandle.getIntersection(subtaskKeyGroupRange);
+
+                       if (intersectedKeyedStateHandle != null) {
+                               
subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
                        }
                }
-               return subtaskKeyGroupStates;
+
+               return subtaskKeyedStateHandles;
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 1393e32..9e195b1 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
@@ -56,12 +56,12 @@ public class SubtaskState implements StateObject {
        /**
         * Snapshot from {@link 
org.apache.flink.runtime.state.KeyedStateBackend}.
         */
-       private final KeyGroupsStateHandle managedKeyedState;
+       private final KeyedStateHandle managedKeyedState;
 
        /**
         * Snapshot written using {@link 
org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}.
         */
-       private final KeyGroupsStateHandle rawKeyedState;
+       private final KeyedStateHandle rawKeyedState;
 
        /**
         * The state size. This is also part of the deserialized state handle.
@@ -74,8 +74,8 @@ public class SubtaskState implements StateObject {
                        ChainedStateHandle<StreamStateHandle> 
legacyOperatorState,
                        ChainedStateHandle<OperatorStateHandle> 
managedOperatorState,
                        ChainedStateHandle<OperatorStateHandle> 
rawOperatorState,
-                       KeyGroupsStateHandle managedKeyedState,
-                       KeyGroupsStateHandle rawKeyedState) {
+                       KeyedStateHandle managedKeyedState,
+                       KeyedStateHandle rawKeyedState) {
 
                this.legacyOperatorState = checkNotNull(legacyOperatorState, 
"State");
                this.managedOperatorState = managedOperatorState;
@@ -114,11 +114,11 @@ public class SubtaskState implements StateObject {
                return rawOperatorState;
        }
 
-       public KeyGroupsStateHandle getManagedKeyedState() {
+       public KeyedStateHandle getManagedKeyedState() {
                return managedKeyedState;
        }
 
-       public KeyGroupsStateHandle getRawKeyedState() {
+       public KeyedStateHandle getRawKeyedState() {
                return rawKeyedState;
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index ba1949a..44461d8 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
@@ -155,11 +156,11 @@ class SavepointV1Serializer implements 
SavepointSerializer<SavepointV1> {
                        serializeOperatorStateHandle(stateHandle, dos);
                }
 
-               KeyGroupsStateHandle keyedStateBackend = 
subtaskState.getManagedKeyedState();
-               serializeKeyGroupStateHandle(keyedStateBackend, dos);
+               KeyedStateHandle keyedStateBackend = 
subtaskState.getManagedKeyedState();
+               serializeKeyedStateHandle(keyedStateBackend, dos);
 
-               KeyGroupsStateHandle keyedStateStream = 
subtaskState.getRawKeyedState();
-               serializeKeyGroupStateHandle(keyedStateStream, dos);
+               KeyedStateHandle keyedStateStream = 
subtaskState.getRawKeyedState();
+               serializeKeyedStateHandle(keyedStateStream, dos);
        }
 
        private static SubtaskState deserializeSubtaskState(DataInputStream 
dis) throws IOException {
@@ -188,9 +189,9 @@ class SavepointV1Serializer implements 
SavepointSerializer<SavepointV1> {
                        operatorStateStream.add(streamStateHandle);
                }
 
-               KeyGroupsStateHandle keyedStateBackend = 
deserializeKeyGroupStateHandle(dis);
+               KeyedStateHandle keyedStateBackend = 
deserializeKeyedStateHandle(dis);
 
-               KeyGroupsStateHandle keyedStateStream = 
deserializeKeyGroupStateHandle(dis);
+               KeyedStateHandle keyedStateStream = 
deserializeKeyedStateHandle(dis);
 
                ChainedStateHandle<StreamStateHandle> 
nonPartitionableStateChain =
                                new ChainedStateHandle<>(nonPartitionableState);
@@ -209,23 +210,27 @@ class SavepointV1Serializer implements 
SavepointSerializer<SavepointV1> {
                                keyedStateStream);
        }
 
-       private static void serializeKeyGroupStateHandle(
-                       KeyGroupsStateHandle stateHandle, DataOutputStream dos) 
throws IOException {
+       private static void serializeKeyedStateHandle(
+                       KeyedStateHandle stateHandle, DataOutputStream dos) 
throws IOException {
+
+               if (stateHandle == null) {
+                       dos.writeByte(NULL_HANDLE);
+               } else if (stateHandle instanceof KeyGroupsStateHandle) {
+                       KeyGroupsStateHandle keyGroupsStateHandle = 
(KeyGroupsStateHandle) stateHandle;
 
-               if (stateHandle != null) {
                        dos.writeByte(KEY_GROUPS_HANDLE);
-                       
dos.writeInt(stateHandle.getGroupRangeOffsets().getKeyGroupRange().getStartKeyGroup());
-                       dos.writeInt(stateHandle.getNumberOfKeyGroups());
-                       for (int keyGroup : stateHandle.keyGroups()) {
-                               
dos.writeLong(stateHandle.getOffsetForKeyGroup(keyGroup));
+                       
dos.writeInt(keyGroupsStateHandle.getKeyGroupRange().getStartKeyGroup());
+                       
dos.writeInt(keyGroupsStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
+                       for (int keyGroup : 
keyGroupsStateHandle.getKeyGroupRange()) {
+                               
dos.writeLong(keyGroupsStateHandle.getOffsetForKeyGroup(keyGroup));
                        }
-                       
serializeStreamStateHandle(stateHandle.getDelegateStateHandle(), dos);
+                       
serializeStreamStateHandle(keyGroupsStateHandle.getDelegateStateHandle(), dos);
                } else {
-                       dos.writeByte(NULL_HANDLE);
+                       throw new IllegalStateException("Unknown 
KeyedStateHandle type: " + stateHandle.getClass());
                }
        }
 
-       private static KeyGroupsStateHandle 
deserializeKeyGroupStateHandle(DataInputStream dis) throws IOException {
+       private static KeyedStateHandle 
deserializeKeyedStateHandle(DataInputStream dis) throws IOException {
                final int type = dis.readByte();
                if (NULL_HANDLE == type) {
                        return null;
@@ -237,11 +242,12 @@ class SavepointV1Serializer implements 
SavepointSerializer<SavepointV1> {
                        for (int i = 0; i < numKeyGroups; ++i) {
                                offsets[i] = dis.readLong();
                        }
-                       KeyGroupRangeOffsets keyGroupRangeOffsets = new 
KeyGroupRangeOffsets(keyGroupRange, offsets);
+                       KeyGroupRangeOffsets keyGroupRangeOffsets = new 
KeyGroupRangeOffsets(
+                               keyGroupRange, offsets);
                        StreamStateHandle stateHandle = 
deserializeStreamStateHandle(dis);
                        return new KeyGroupsStateHandle(keyGroupRangeOffsets, 
stateHandle);
                } else {
-                       throw new IllegalStateException("Reading invalid 
KeyGroupsStateHandle, type: " + type);
+                       throw new IllegalStateException("Reading invalid 
KeyedStateHandle, type: " + type);
                }
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index e6e7b23..e86f1f8 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -61,7 +61,7 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  * @param <K> Type of the key by which state is keyed.
  */
 public abstract class AbstractKeyedStateBackend<K>
-               implements KeyedStateBackend<K>, 
Snapshotable<KeyGroupsStateHandle>, Closeable {
+               implements KeyedStateBackend<K>, 
Snapshotable<KeyedStateHandle>, Closeable {
 
        /** {@link TypeSerializer} for our key. */
        protected final TypeSerializer<K> keySerializer;

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
index b454e42..bad7fd4 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
@@ -29,7 +29,7 @@ import java.io.IOException;
  * consists of a range of key group snapshots. A key group is subset of the 
available
  * key space. The key groups are identified by their key group indices.
  */
-public class KeyGroupsStateHandle implements StreamStateHandle {
+public class KeyGroupsStateHandle implements StreamStateHandle, 
KeyedStateHandle {
 
        private static final long serialVersionUID = -8070326169926626355L;
 
@@ -54,20 +54,18 @@ public class KeyGroupsStateHandle implements 
StreamStateHandle {
 
        /**
         *
-        * @return iterable over the key-group range for the key-group state 
referenced by this handle
+        * @return the internal key-group range to offsets metadata
         */
-       public Iterable<Integer> keyGroups() {
-               return groupRangeOffsets.getKeyGroupRange();
+       public KeyGroupRangeOffsets getGroupRangeOffsets() {
+               return groupRangeOffsets;
        }
 
-
        /**
         *
-        * @param keyGroupId the id of a key-group
-        * @return true if the provided key-group id is contained in the 
key-group range of this handle
+        * @return The handle to the actual states
         */
-       public boolean containsKeyGroup(int keyGroupId) {
-               return 
groupRangeOffsets.getKeyGroupRange().contains(keyGroupId);
+       public StreamStateHandle getDelegateStateHandle() {
+               return stateHandle;
        }
 
        /**
@@ -85,24 +83,13 @@ public class KeyGroupsStateHandle implements 
StreamStateHandle {
         * @return key-group state over a range that is the intersection 
between this handle's key-group range and the
         *          provided key-group range.
         */
-       public KeyGroupsStateHandle getKeyGroupIntersection(KeyGroupRange 
keyGroupRange) {
+       public KeyGroupsStateHandle getIntersection(KeyGroupRange 
keyGroupRange) {
                return new 
KeyGroupsStateHandle(groupRangeOffsets.getIntersection(keyGroupRange), 
stateHandle);
        }
 
-       /**
-        *
-        * @return the internal key-group range to offsets metadata
-        */
-       public KeyGroupRangeOffsets getGroupRangeOffsets() {
-               return groupRangeOffsets;
-       }
-
-       /**
-        *
-        * @return number of key-groups in the key-group range of this handle
-        */
-       public int getNumberOfKeyGroups() {
-               return 
groupRangeOffsets.getKeyGroupRange().getNumberOfKeyGroups();
+       @Override
+       public KeyGroupRange getKeyGroupRange() {
+               return groupRangeOffsets.getKeyGroupRange();
        }
 
        @Override
@@ -120,10 +107,6 @@ public class KeyGroupsStateHandle implements 
StreamStateHandle {
                return stateHandle.openInputStream();
        }
 
-       public StreamStateHandle getDelegateStateHandle() {
-               return stateHandle;
-       }
-
        @Override
        public boolean equals(Object o) {
                if (this == o) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
new file mode 100644
index 0000000..dc9c97d
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * Base for the handles of the checkpointed states in keyed streams. When
+ * recovering from failures, the handle will be passed to all tasks whose key
+ * group ranges overlap with it.
+ */
+public interface KeyedStateHandle extends StateObject {
+
+       /**
+        * Returns the range of the key groups contained in the state.
+        */
+       KeyGroupRange getKeyGroupRange();
+
+       /**
+        * Returns a state over a range that is the intersection between this
+        * handle's key-group range and the provided key-group range.
+        *
+        * @param keyGroupRange The key group range to intersect with
+        */
+       KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange);
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
index 886d214..d82af72 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -27,9 +27,11 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
+import java.util.List;
 import java.util.NoSuchElementException;
 
 /**
@@ -55,7 +57,7 @@ public class StateInitializationContextImpl implements 
StateInitializationContex
                        boolean restored,
                        OperatorStateStore operatorStateStore,
                        KeyedStateStore keyedStateStore,
-                       Collection<KeyGroupsStateHandle> keyGroupsStateHandles,
+                       Collection<KeyedStateHandle> keyedStateHandles,
                        Collection<OperatorStateHandle> operatorStateHandles,
                        CloseableRegistry closableRegistry) {
 
@@ -64,7 +66,7 @@ public class StateInitializationContextImpl implements 
StateInitializationContex
                this.operatorStateStore = operatorStateStore;
                this.keyedStateStore = keyedStateStore;
                this.operatorStateHandles = operatorStateHandles;
-               this.keyGroupsStateHandles = keyGroupsStateHandles;
+               this.keyGroupsStateHandles = transform(keyedStateHandles);
 
                this.keyedStateIterable = keyGroupsStateHandles == null ?
                                null
@@ -136,6 +138,26 @@ public class StateInitializationContextImpl implements 
StateInitializationContex
                IOUtils.closeQuietly(closableRegistry);
        }
 
+       private static Collection<KeyGroupsStateHandle> 
transform(Collection<KeyedStateHandle> keyedStateHandles) {
+               if (keyedStateHandles == null) {
+                       return null;
+               }
+
+               List<KeyGroupsStateHandle> keyGroupsStateHandles = new 
ArrayList<>();
+
+               for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+                       if (! (keyedStateHandle instanceof 
KeyGroupsStateHandle)) {
+                               throw new IllegalStateException("Unexpected 
state handle type, " +
+                                       "expected: " + 
KeyGroupsStateHandle.class +
+                                       ", but found: " + 
keyedStateHandle.getClass() + ".");
+                       }
+
+                       keyGroupsStateHandles.add((KeyGroupsStateHandle) 
keyedStateHandle);
+               }
+
+               return keyGroupsStateHandles;
+       }
+
        private static class KeyGroupStreamIterator
                        extends 
AbstractStateStreamIterator<KeyGroupStatePartitionStreamProvider, 
KeyGroupsStateHandle> {
 
@@ -159,7 +181,7 @@ public class StateInitializationContextImpl implements 
StateInitializationContex
 
                        while (stateHandleIterator.hasNext()) {
                                currentStateHandle = stateHandleIterator.next();
-                               if (currentStateHandle.getNumberOfKeyGroups() > 
0) {
+                               if 
(currentStateHandle.getKeyGroupRange().getNumberOfKeyGroups() > 0) {
                                        currentOffsetsIterator = 
currentStateHandle.getGroupRangeOffsets().iterator();
 
                                        return true;

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
index 96edccb..5db0138 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
@@ -109,15 +109,17 @@ public class StateSnapshotContextSynchronousImpl 
implements StateSnapshotContext
                return operatorStateCheckpointOutputStream;
        }
 
-       public RunnableFuture<KeyGroupsStateHandle> getKeyedStateStreamFuture() 
throws IOException {
-               return 
closeAndUnregisterStreamToObtainStateHandle(keyedStateCheckpointOutputStream);
+       public RunnableFuture<KeyedStateHandle> getKeyedStateStreamFuture() 
throws IOException {
+               KeyGroupsStateHandle keyGroupsStateHandle = 
closeAndUnregisterStreamToObtainStateHandle(keyedStateCheckpointOutputStream);
+               return new DoneFuture<KeyedStateHandle>(keyGroupsStateHandle);
        }
 
        public RunnableFuture<OperatorStateHandle> 
getOperatorStateStreamFuture() throws IOException {
-               return 
closeAndUnregisterStreamToObtainStateHandle(operatorStateCheckpointOutputStream);
+               OperatorStateHandle operatorStateHandle = 
closeAndUnregisterStreamToObtainStateHandle(operatorStateCheckpointOutputStream);
+               return new DoneFuture<>(operatorStateHandle);
        }
 
-       private <T extends StreamStateHandle> RunnableFuture<T> 
closeAndUnregisterStreamToObtainStateHandle(
+       private <T extends StreamStateHandle> T 
closeAndUnregisterStreamToObtainStateHandle(
                        NonClosingCheckpointOutputStream<T> stream) throws 
IOException {
                if (null == stream) {
                        return null;
@@ -126,7 +128,7 @@ public class StateSnapshotContextSynchronousImpl implements 
StateSnapshotContext
                closableRegistry.unregisterClosable(stream.getDelegate());
 
                // for now we only support synchronous writing
-               return new DoneFuture<>(stream.closeAndGetHandle());
+               return stream.closeAndGetHandle();
        }
 
        private <T extends StreamStateHandle> void 
closeAndUnregisterStream(NonClosingCheckpointOutputStream<T> stream) throws 
IOException {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
index 417a9dd..450413a 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
@@ -40,10 +40,10 @@ public class TaskStateHandles implements Serializable {
        private final ChainedStateHandle<StreamStateHandle> legacyOperatorState;
 
        /** Collection of handles which represent the managed keyed state of 
the head operator */
-       private final Collection<KeyGroupsStateHandle> managedKeyedState;
+       private final Collection<KeyedStateHandle> managedKeyedState;
 
        /** Collection of handles which represent the raw/streamed keyed state 
of the head operator */
-       private final Collection<KeyGroupsStateHandle> rawKeyedState;
+       private final Collection<KeyedStateHandle> rawKeyedState;
 
        /** Outer list represents the operator chain, each collection holds 
handles for managed state of a single operator */
        private final List<Collection<OperatorStateHandle>> 
managedOperatorState;
@@ -67,8 +67,8 @@ public class TaskStateHandles implements Serializable {
                        ChainedStateHandle<StreamStateHandle> 
legacyOperatorState,
                        List<Collection<OperatorStateHandle>> 
managedOperatorState,
                        List<Collection<OperatorStateHandle>> rawOperatorState,
-                       Collection<KeyGroupsStateHandle> managedKeyedState,
-                       Collection<KeyGroupsStateHandle> rawKeyedState) {
+                       Collection<KeyedStateHandle> managedKeyedState,
+                       Collection<KeyedStateHandle> rawKeyedState) {
 
                this.legacyOperatorState = legacyOperatorState;
                this.managedKeyedState = managedKeyedState;
@@ -82,11 +82,11 @@ public class TaskStateHandles implements Serializable {
                return legacyOperatorState;
        }
 
-       public Collection<KeyGroupsStateHandle> getManagedKeyedState() {
+       public Collection<KeyedStateHandle> getManagedKeyedState() {
                return managedKeyedState;
        }
 
-       public Collection<KeyGroupsStateHandle> getRawKeyedState() {
+       public Collection<KeyedStateHandle> getRawKeyedState() {
                return rawKeyedState;
        }
 
@@ -110,8 +110,8 @@ public class TaskStateHandles implements Serializable {
                return out;
        }
 
-       private static List<KeyGroupsStateHandle> 
transform(KeyGroupsStateHandle in) {
-               return in == null ? 
Collections.<KeyGroupsStateHandle>emptyList() : Collections.singletonList(in);
+       private static <T> List<T> transform(T in) {
+               return in == null ? Collections.<T>emptyList() : 
Collections.singletonList(in);
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 46ec5c2..a332d7d 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -39,6 +39,7 @@ import 
org.apache.flink.migration.runtime.state.KvStateSnapshot;
 import 
org.apache.flink.migration.runtime.state.memory.MigrationRestoreSnapshot;
 import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable;
 import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
+import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
@@ -50,6 +51,7 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
@@ -223,7 +225,7 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
        @Override
        @SuppressWarnings("unchecked")
-       public  RunnableFuture<KeyGroupsStateHandle> snapshot(
+       public  RunnableFuture<KeyedStateHandle> snapshot(
                        final long checkpointId,
                        final long timestamp,
                        final CheckpointStreamFactory streamFactory,
@@ -267,8 +269,8 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                //--------------------------------------------------- this 
becomes the end of sync part
 
                // implementation of the async IO operation, based on FutureTask
-               final AbstractAsyncIOCallable<KeyGroupsStateHandle, 
CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
-                               new 
AbstractAsyncIOCallable<KeyGroupsStateHandle, 
CheckpointStreamFactory.CheckpointStateOutputStream>() {
+               final AbstractAsyncIOCallable<KeyedStateHandle, 
CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
+                               new AbstractAsyncIOCallable<KeyedStateHandle, 
CheckpointStreamFactory.CheckpointStateOutputStream>() {
 
                                        AtomicBoolean open = new 
AtomicBoolean(false);
 
@@ -340,7 +342,7 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                                        }
                                };
 
-               AsyncStoppableTaskWithCallback<KeyGroupsStateHandle> task = 
AsyncStoppableTaskWithCallback.from(ioCallable);
+               AsyncStoppableTaskWithCallback<KeyedStateHandle> task = 
AsyncStoppableTaskWithCallback.from(ioCallable);
 
                if (!asynchronousSnapshots) {
                        task.run();
@@ -354,7 +356,7 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
        @SuppressWarnings("deprecation")
        @Override
-       public void restore(Collection<KeyGroupsStateHandle> restoredState) 
throws Exception {
+       public void restore(Collection<KeyedStateHandle> restoredState) throws 
Exception {
                LOG.info("Initializing heap keyed state backend from 
snapshot.");
 
                if (LOG.isDebugEnabled()) {
@@ -369,19 +371,26 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        }
 
        @SuppressWarnings({"unchecked"})
-       private void restorePartitionedState(Collection<KeyGroupsStateHandle> 
state) throws Exception {
+       private void restorePartitionedState(Collection<KeyedStateHandle> 
state) throws Exception {
 
                final Map<Integer, String> kvStatesById = new HashMap<>();
                int numRegisteredKvStates = 0;
                stateTables.clear();
 
-               for (KeyGroupsStateHandle keyGroupsHandle : state) {
+               for (KeyedStateHandle keyedStateHandle : state) {
 
-                       if (keyGroupsHandle == null) {
+                       if (keyedStateHandle == null) {
                                continue;
                        }
 
-                       FSDataInputStream fsDataInputStream = 
keyGroupsHandle.openInputStream();
+                       if (!(keyedStateHandle instanceof 
KeyGroupsStateHandle)) {
+                               throw new IllegalStateException("Unexpected 
state handle type, " +
+                                               "expected: " + 
KeyGroupsStateHandle.class +
+                                               ", but found: " + 
keyedStateHandle.getClass());
+                       }
+
+                       KeyGroupsStateHandle keyGroupsStateHandle = 
(KeyGroupsStateHandle) keyedStateHandle;
+                       FSDataInputStream fsDataInputStream = 
keyGroupsStateHandle.openInputStream();
                        
cancelStreamRegistry.registerClosable(fsDataInputStream);
 
                        try {
@@ -412,9 +421,13 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                                        }
                                }
 
-                               for (Tuple2<Integer, Long> groupOffset : 
keyGroupsHandle.getGroupRangeOffsets()) {
+                               for (Tuple2<Integer, Long> groupOffset : 
keyGroupsStateHandle.getGroupRangeOffsets()) {
                                        int keyGroupIndex = groupOffset.f0;
                                        long offset = groupOffset.f1;
+
+                                       // Check that restored key groups all 
belong to the backend.
+                                       
Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group 
must belong to the backend.");
+
                                        fsDataInputStream.seek(offset);
 
                                        int writtenKeyGroupIndex = 
inView.readInt();
@@ -449,7 +462,7 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        @SuppressWarnings({"unchecked", "rawtypes", "DeprecatedIsStillUsed"})
        @Deprecated
        private void restoreOldSavepointKeyedState(
-                       Collection<KeyGroupsStateHandle> stateHandles) throws 
IOException, ClassNotFoundException {
+                       Collection<KeyedStateHandle> stateHandles) throws 
IOException, ClassNotFoundException {
 
                if (stateHandles.isEmpty()) {
                        return;
@@ -457,8 +470,17 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
                Preconditions.checkState(1 == stateHandles.size(), "Only one 
element expected here.");
 
+               KeyedStateHandle keyedStateHandle = 
stateHandles.iterator().next();
+               if (!(keyedStateHandle instanceof 
MigrationKeyGroupStateHandle)) {
+                       throw new IllegalStateException("Unexpected state 
handle type, " +
+                                       "expected: " + 
MigrationKeyGroupStateHandle.class +
+                                       ", but found " + 
keyedStateHandle.getClass());
+               }
+
+               MigrationKeyGroupStateHandle keyGroupStateHandle = 
(MigrationKeyGroupStateHandle) keyedStateHandle;
+
                HashMap<String, KvStateSnapshot<K, ?, ?, ?>> namedStates;
-               try (FSDataInputStream inputStream = 
stateHandles.iterator().next().openInputStream()) {
+               try (FSDataInputStream inputStream = 
keyGroupStateHandle.openInputStream()) {
                        namedStates = 
InstantiationUtil.deserializeObject(inputStream, userCodeClassLoader);
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index d8bba59..117c70d 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -41,6 +41,7 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -2346,13 +2347,13 @@ public class CheckpointCoordinatorTest {
                        ChainedStateHandle<StreamStateHandle> operatorState = 
taskStateHandles.getLegacyOperatorState();
                        List<Collection<OperatorStateHandle>> opStateBackend = 
taskStateHandles.getManagedOperatorState();
                        List<Collection<OperatorStateHandle>> opStateRaw = 
taskStateHandles.getRawOperatorState();
-                       Collection<KeyGroupsStateHandle> keyGroupStateBackend = 
taskStateHandles.getManagedKeyedState();
-                       Collection<KeyGroupsStateHandle> keyGroupStateRaw = 
taskStateHandles.getRawKeyedState();
+                       Collection<KeyedStateHandle> keyedStateBackend = 
taskStateHandles.getManagedKeyedState();
+                       Collection<KeyedStateHandle> keyGroupStateRaw = 
taskStateHandles.getRawKeyedState();
 
                        actualOpStatesBackend.add(opStateBackend);
                        actualOpStatesRaw.add(opStateRaw);
                        assertNull(operatorState);
-                       
compareKeyedState(Collections.singletonList(originalKeyedStateBackend), 
keyGroupStateBackend);
+                       
compareKeyedState(Collections.singletonList(originalKeyedStateBackend), 
keyedStateBackend);
                        
compareKeyedState(Collections.singletonList(originalKeyedStateRaw), 
keyGroupStateRaw);
                }
                comparePartitionableState(expectedOpStatesBackend, 
actualOpStatesBackend);
@@ -2690,32 +2691,38 @@ public class CheckpointCoordinatorTest {
 
                        KeyGroupsStateHandle expectPartitionedKeyGroupState = 
generateKeyGroupState(
                                        jobVertexID, keyGroupPartitions.get(i), 
false);
-                       Collection<KeyGroupsStateHandle> 
actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
+                       Collection<KeyedStateHandle> 
actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
                        
compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), 
actualPartitionedKeyGroupState);
                }
        }
 
        public static void compareKeyedState(
                        Collection<KeyGroupsStateHandle> 
expectPartitionedKeyGroupState,
-                       Collection<KeyGroupsStateHandle> 
actualPartitionedKeyGroupState) throws Exception {
+                       Collection<? extends KeyedStateHandle> 
actualPartitionedKeyGroupState) throws Exception {
 
                KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = 
expectPartitionedKeyGroupState.iterator().next();
-               int expectedTotalKeyGroups = 
expectedHeadOpKeyGroupStateHandle.getNumberOfKeyGroups();
+               int expectedTotalKeyGroups = 
expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
                int actualTotalKeyGroups = 0;
-               for(KeyGroupsStateHandle keyGroupsStateHandle: 
actualPartitionedKeyGroupState) {
-                       actualTotalKeyGroups += 
keyGroupsStateHandle.getNumberOfKeyGroups();
+               for(KeyedStateHandle keyedStateHandle: 
actualPartitionedKeyGroupState) {
+                       assertTrue(keyedStateHandle instanceof 
KeyGroupsStateHandle);
+
+                       actualTotalKeyGroups += 
keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
                }
 
                assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
 
                try (FSDataInputStream inputStream = 
expectedHeadOpKeyGroupStateHandle.openInputStream()) {
-                       for (int groupId : 
expectedHeadOpKeyGroupStateHandle.keyGroups()) {
+                       for (int groupId : 
expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
                                long offset = 
expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
                                inputStream.seek(offset);
                                int expectedKeyGroupState =
                                                
InstantiationUtil.deserializeObject(inputStream, 
Thread.currentThread().getContextClassLoader());
-                               for (KeyGroupsStateHandle 
oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
-                                       if 
(oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
+                               for (KeyedStateHandle oneActualKeyedStateHandle 
: actualPartitionedKeyGroupState) {
+
+                                       assertTrue(oneActualKeyedStateHandle 
instanceof KeyGroupsStateHandle);
+
+                                       KeyGroupsStateHandle 
oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
+                                       if 
(oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
                                                long actualOffset = 
oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
                                                try (FSDataInputStream 
actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
                                                        
actualInputStream.seek(actualOffset);

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 18b07eb..7e0a7c1 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -31,6 +31,7 @@ import 
org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -68,7 +69,7 @@ public class CheckpointStateRestoreTest {
                        final ChainedStateHandle<StreamStateHandle> 
serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new 
SerializableObject());
                        KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
                        List<SerializableObject> testStates = 
Collections.singletonList(new SerializableObject());
-                       final KeyGroupsStateHandle serializedKeyGroupStates = 
CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+                       final KeyedStateHandle serializedKeyGroupStates = 
CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
 
                        final JobID jid = new JobID();
                        final JobVertexID statefulId = new JobVertexID();

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
index 6ab8620..1ecb2e3 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.TaskState;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
@@ -58,6 +59,7 @@ import java.util.concurrent.ThreadLocalRandom;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
 @SuppressWarnings("deprecation")
 public class MigrationV0ToV1Test {
@@ -154,9 +156,15 @@ public class MigrationV0ToV1Test {
                                        }
 
                                        //check keyed state
-                                       KeyGroupsStateHandle 
keyGroupsStateHandle = subtaskState.getManagedKeyedState();
+                                       KeyedStateHandle keyedStateHandle = 
subtaskState.getManagedKeyedState();
+
                                        if (t % 3 != 0) {
-                                               assertEquals(1, 
keyGroupsStateHandle.getNumberOfKeyGroups());
+
+                                               assertTrue(keyedStateHandle 
instanceof KeyGroupsStateHandle);
+
+                                               KeyGroupsStateHandle 
keyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle;
+
+                                               assertEquals(1, 
keyGroupsStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
                                                assertEquals(p, 
keyGroupsStateHandle.getGroupRangeOffsets().getKeyGroupRange().getStartKeyGroup());
 
                                                ByteStreamStateHandle 
stateHandle =
@@ -172,7 +180,7 @@ public class MigrationV0ToV1Test {
                                                        assertEquals(p, 
data[1]);
                                                }
                                        } else {
-                                               assertEquals(null, 
keyGroupsStateHandle);
+                                               assertEquals(null, 
keyedStateHandle);
                                        }
                                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
index 0c4ed74..cee0b02 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
@@ -135,7 +135,7 @@ public class KeyedStateCheckpointOutputStreamTest {
                int count = 0;
                try (FSDataInputStream in = fullHandle.openInputStream()) {
                        DataInputView div = new DataInputViewStreamWrapper(in);
-                       for (int kg : fullHandle.keyGroups()) {
+                       for (int kg : fullHandle.getKeyGroupRange()) {
                                long off = fullHandle.getOffsetForKeyGroup(kg);
                                if (off >= 0) {
                                        in.seek(off);
@@ -152,7 +152,7 @@ public class KeyedStateCheckpointOutputStreamTest {
                int count = 0;
                try (FSDataInputStream in = fullHandle.openInputStream()) {
                        DataInputView div = new DataInputViewStreamWrapper(in);
-                       for (int kg : fullHandle.keyGroups()) {
+                       for (int kg : fullHandle.getKeyGroupRange()) {
                                long off = fullHandle.getOffsetForKeyGroup(kg);
                                in.seek(off);
                                Assert.assertEquals(kg, div.readInt());

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 22bb715..ccc1eae 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -143,13 +143,13 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                                env.getTaskKvStateRegistry());
        }
 
-       protected <K> AbstractKeyedStateBackend<K> 
restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle 
state) throws Exception {
+       protected <K> AbstractKeyedStateBackend<K> 
restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyedStateHandle state) 
throws Exception {
                return restoreKeyedBackend(keySerializer, state, new 
DummyEnvironment("test", 1, 0));
        }
 
        protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
                        TypeSerializer<K> keySerializer,
-                       KeyGroupsStateHandle state,
+                       KeyedStateHandle state,
                        Environment env) throws Exception {
                return restoreKeyedBackend(
                                keySerializer,
@@ -163,7 +163,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        TypeSerializer<K> keySerializer,
                        int numberOfKeyGroups,
                        KeyGroupRange keyGroupRange,
-                       List<KeyGroupsStateHandle> state,
+                       List<KeyedStateHandle> state,
                        Environment env) throws Exception {
 
                AbstractKeyedStateBackend<K> backend = 
getStateBackend().createKeyedStateBackend(
@@ -436,7 +436,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                backend.setCurrentKey(2);
                state.update(new TestPojo("u2", 2));
 
-               KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+               KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
                                682375462378L,
                                2,
                                streamFactory,
@@ -497,7 +497,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                backend.setCurrentKey(2);
                state.update(new TestPojo("u2", 2));
 
-               KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+               KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
                                682375462378L,
                                2,
                                streamFactory,
@@ -524,7 +524,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                // update to test state backends that eagerly serialize, such 
as RocksDB
                state.update(new TestPojo("u1", 11));
 
-               KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(
+               KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
                                682375462378L,
                                2,
                                streamFactory,
@@ -585,7 +585,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                backend.setCurrentKey(2);
                state.update(new TestPojo("u2", 2));
 
-               KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+               KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
                                682375462378L,
                                2,
                                streamFactory,
@@ -611,7 +611,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                // update to test state backends that eagerly serialize, such 
as RocksDB
                state.update(new TestPojo("u1", 11));
 
-               KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(
+               KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
                                682375462378L,
                                2,
                                streamFactory,
@@ -670,7 +670,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals("1", getSerializedValue(kvState, 1, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
                // draw a snapshot
-               KeyGroupsStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -681,7 +681,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.update("u3");
 
                // draw another snapshot
-               KeyGroupsStateHandle snapshot2 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot2 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -880,7 +880,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals(13, (int) state2.value());
 
                // draw a snapshot
-               KeyGroupsStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                backend.dispose();
                backend = restoreKeyedBackend(
@@ -952,7 +952,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals(42L, (long) state.value());
 
                // draw a snapshot
-               KeyGroupsStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                backend.dispose();
                backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
@@ -997,7 +997,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals("1", joiner.join(getSerializedList(kvState, 1, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
                // draw a snapshot
-               KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -1008,7 +1008,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.add("u3");
 
                // draw another snapshot
-               KeyGroupsStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -1091,7 +1091,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals("1", getSerializedValue(kvState, 1, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
                // draw a snapshot
-               KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -1102,7 +1102,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.add("u3");
 
                // draw another snapshot
-               KeyGroupsStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -1188,7 +1188,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals("Fold-Initial:,1", getSerializedValue(kvState, 1, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
                // draw a snapshot
-               KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -1200,7 +1200,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.add(103);
 
                // draw another snapshot
-               KeyGroupsStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -1287,7 +1287,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                                getSerializedMap(kvState, 1, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
 
                // draw a snapshot
-               KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -1299,7 +1299,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.putAll(new HashMap<Integer, String>() {{ put(1031, 
"1031"); put(1032, "1032"); }});
 
                // draw another snapshot
-               KeyGroupsStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -1606,13 +1606,13 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.update("ShouldBeInSecondHalf");
 
 
-               KeyGroupsStateHandle snapshot = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(0, 0, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(0, 0, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
-               List<KeyGroupsStateHandle> firstHalfKeyGroupStates = 
StateAssignmentOperation.getKeyGroupsStateHandles(
+               List<KeyedStateHandle> firstHalfKeyGroupStates = 
StateAssignmentOperation.getKeyedStateHandles(
                                Collections.singletonList(snapshot),
                                
KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 
2, 0));
 
-               List<KeyGroupsStateHandle> secondHalfKeyGroupStates = 
StateAssignmentOperation.getKeyGroupsStateHandles(
+               List<KeyedStateHandle> secondHalfKeyGroupStates = 
StateAssignmentOperation.getKeyedStateHandles(
                                Collections.singletonList(snapshot),
                                
KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 
2, 1));
 
@@ -1672,7 +1672,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.update("2");
 
                        // draw a snapshot
-                       KeyGroupsStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+                       KeyedStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                        backend.dispose();
                        // restore the first snapshot and validate it
@@ -1723,7 +1723,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.add("2");
 
                        // draw a snapshot
-                       KeyGroupsStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+                       KeyedStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                        backend.dispose();
                        // restore the first snapshot and validate it
@@ -1776,7 +1776,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.add("2");
 
                        // draw a snapshot
-                       KeyGroupsStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+                       KeyedStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                        backend.dispose();
                        // restore the first snapshot and validate it
@@ -1827,7 +1827,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.put("2", "Second");
 
                        // draw a snapshot
-                       KeyGroupsStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+                       KeyedStateHandle snapshot1 = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                        backend.dispose();
                        // restore the first snapshot and validate it
@@ -2093,7 +2093,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                                eq(env.getJobID()), eq(env.getJobVertexId()), 
eq(expectedKeyGroupRange), eq("banana"), any(KvStateID.class));
 
 
-               KeyGroupsStateHandle snapshot = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+               KeyedStateHandle snapshot = 
FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
 
                backend.dispose();
 
@@ -2124,7 +2124,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
                        ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 
                        // draw a snapshot
-                       KeyGroupsStateHandle snapshot =
+                       KeyedStateHandle snapshot =
                                        FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 1, streamFactory, CheckpointOptions.forFullCheckpoint()));
                        assertNull(snapshot);
                        backend.dispose();
@@ -2152,7 +2152,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
                streamFactory.setWaiterLatch(waiter);
 
                AbstractKeyedStateBackend<Integer> backend = null;
-               KeyGroupsStateHandle stateHandle = null;
+               KeyedStateHandle stateHandle = null;
 
                try {
                        backend = createKeyedBackend(IntSerializer.INSTANCE);
@@ -2167,7 +2167,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
                                valueState.update(i);
                        }
 
-                       RunnableFuture<KeyGroupsStateHandle> snapshot =
+                       RunnableFuture<KeyedStateHandle> snapshot =
                                        backend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forFullCheckpoint());
                        Thread runner = new Thread(snapshot);
                        runner.start();
@@ -2249,7 +2249,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
                                valueState.update(i);
                        }
 
-                       RunnableFuture<KeyGroupsStateHandle> snapshot =
+                       RunnableFuture<KeyedStateHandle> snapshot =
                                        backend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forFullCheckpoint());
 
                        Thread runner = new Thread(snapshot);
@@ -2367,7 +2367,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
                }
        }
 
-       private KeyGroupsStateHandle runSnapshot(RunnableFuture<KeyGroupsStateHandle> snapshotRunnableFuture) throws Exception {
+       private KeyedStateHandle runSnapshot(RunnableFuture<KeyedStateHandle> snapshotRunnableFuture) throws Exception {
                if(!snapshotRunnableFuture.isDone()) {
                        Thread runner = new Thread(snapshotRunnableFuture);
                        runner.start();

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
index da0666a..3754d63 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.internal.InternalListState;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
@@ -63,7 +64,7 @@ public class HeapKeyedStateBackendSnapshotMigrationTest extends HeapStateBackend
                        try (BufferedInputStream bis = new BufferedInputStream((new FileInputStream(resource.getFile())))) {
                                stateHandle = InstantiationUtil.deserializeObject(bis, Thread.currentThread().getContextClassLoader());
                        }
-                       keyedBackend.restore(Collections.singleton(stateHandle));
+                       keyedBackend.restore(Collections.<KeyedStateHandle>singleton(stateHandle));
                        final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);
                        stateDescr.initializeSerializerUnlessSet(new ExecutionConfig());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index e40a59b..a6a89b5 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import java.io.IOException;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
 import org.apache.flink.annotation.PublicEvolving;
@@ -47,9 +46,9 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
 import org.apache.flink.runtime.state.KeyGroupsList;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateInitializationContext;
@@ -70,6 +69,7 @@ import org.apache.flink.util.OutputTag;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.Collection;
 import java.util.ConcurrentModificationException;
@@ -198,7 +198,7 @@ public abstract class AbstractStreamOperator<OUT>
        @Override
        public final void initializeState(OperatorStateHandles stateHandles) throws Exception {
 
-               Collection<KeyGroupsStateHandle> keyedStateHandlesRaw = null;
+               Collection<KeyedStateHandle> keyedStateHandlesRaw = null;
                Collection<OperatorStateHandle> operatorStateHandlesRaw = null;
                Collection<OperatorStateHandle> operatorStateHandlesBackend = null;
 
@@ -473,6 +473,7 @@ public abstract class AbstractStreamOperator<OUT>
                        // and then initialize the timer services
                        for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) {
                                int keyGroupIdx = streamProvider.getKeyGroupId();
+
                                checkArgument(localKeyGroupRange.contains(keyGroupIdx),
                                        "Key Group " + keyGroupIdx + " does not belong to the local range.");
 

Reply via email to