This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 9d273a3fcda4033e1a385c0ff0c4a2b7ad640721
Author: Stefan Richter <[email protected]>
AuthorDate: Fri Jul 27 14:49:27 2018 +0200

    [FLINK-9887][state] Integrate priority queue state with existing serializer upgrade mechanism
    
    This closes #6467.
---
 .../apache/flink/util/StateMigrationException.java |   6 +
 .../runtime/state/AbstractKeyedStateBackend.java   |   2 +-
 .../runtime/state/DefaultOperatorStateBackend.java |   4 +-
 .../RegisteredKeyValueStateBackendMetaInfo.java    |  10 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  | 459 +++++++++++------
 .../state/heap/HeapPriorityQueueSetFactory.java    |   2 +-
 .../HeapPriorityQueueSnapshotRestoreWrapper.java   |  22 +
 .../state/InternalPriorityQueueTestBase.java       |  69 ++-
 .../runtime/state/MemoryStateBackendTest.java      |  16 +-
 .../flink/runtime/state/StateBackendTestBase.java  | 549 ++++++++++++++-------
 ...HeapKeyedStateBackendSnapshotMigrationTest.java |   2 +-
 .../state/ttl/mock/MockKeyedStateBackend.java      |   2 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  51 +-
 .../streaming/state/RocksDBStateBackendTest.java   |   6 +
 flink-streaming-java/pom.xml                       |   2 +
 .../api/operators/InternalTimeServiceManager.java  |  18 +-
 .../operators/StreamTaskStateInitializerImpl.java  |   1 -
 .../streaming/api/operators/TimerSerializer.java   |  57 ++-
 .../operators/InternalTimeServiceManagerTest.java  |  31 +-
 .../api/operators/TimerSerializerTest.java         |  62 +++
 .../operators/windowing/TriggerTestHarness.java    |   4 +-
 .../KeyedOneInputStreamOperatorTestHarness.java    |   4 +-
 .../KeyedTwoInputStreamOperatorTestHarness.java    |   2 +-
 23 files changed, 987 insertions(+), 394 deletions(-)

diff --git a/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java b/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
index 00e0e73..12f3ee4 100644
--- a/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
+++ b/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
@@ -24,6 +24,8 @@ package org.apache.flink.util;
 public class StateMigrationException extends FlinkException {
        private static final long serialVersionUID = 8268516412747670839L;
 
+       public static final String MIGRATION_NOT_SUPPORTED_MSG = "State 
migration is currently not supported.";
+
        public StateMigrationException(String message) {
                super(message);
        }
@@ -35,4 +37,8 @@ public class StateMigrationException extends FlinkException {
        public StateMigrationException(String message, Throwable cause) {
                super(message, cause);
        }
+
+       public static StateMigrationException notSupported() {
+               return new StateMigrationException(MIGRATION_NOT_SUPPORTED_MSG);
+       }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index 17d24f77..1c2d2a3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -313,7 +313,7 @@ public abstract class AbstractKeyedStateBackend<K> implements
         * Returns the total number of state entries across all keys/namespaces.
         */
        @VisibleForTesting
-       public abstract int numStateEntries();
+       public abstract int numKeyValueStateEntries();
 
        // TODO remove this once heap-based timers are working with RocksDB 
incremental snapshots!
        public boolean requiresLegacySynchronousTimerSnapshots() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index f1d0b57..dfff50d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -260,7 +260,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
                                // the new serializer; we're deliberately 
failing here for now to have equal functionality with
                                // the RocksDB backend to avoid confusion for 
users.
 
-                               throw new StateMigrationException("State 
migration isn't supported, yet.");
+                               throw StateMigrationException.notSupported();
                        }
                }
 
@@ -781,7 +781,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
                                // the new serializer; we're deliberately 
failing here for now to have equal functionality with
                                // the RocksDB backend to avoid confusion for 
users.
 
-                               throw new StateMigrationException("State 
migration isn't supported, yet.");
+                               throw StateMigrationException.notSupported();
                        }
                }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
index d49a05c..b0248fc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
@@ -144,6 +144,12 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
                TypeSerializer<N> newNamespaceSerializer,
                StateDescriptor<?, S> newStateDescriptor) throws 
StateMigrationException {
 
+               
Preconditions.checkState(restoredStateMetaInfoSnapshot.getBackendStateType()
+                               == 
StateMetaInfoSnapshot.BackendStateType.KEY_VALUE,
+                       "Incompatible state types. " +
+                               "Was [" + 
restoredStateMetaInfoSnapshot.getBackendStateType() + "], " +
+                               "registered as [" + 
StateMetaInfoSnapshot.BackendStateType.KEY_VALUE + "].");
+
                Preconditions.checkState(
                        Objects.equals(newStateDescriptor.getName(), 
restoredStateMetaInfoSnapshot.getName()),
                        "Incompatible state names. " +
@@ -160,7 +166,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 
                        Preconditions.checkState(
                                newStateDescriptor.getType() == restoredType,
-                               "Incompatible state types. " +
+                               "Incompatible key/value state types. " +
                                        "Was [" + restoredType + "], " +
                                        "registered with [" + 
newStateDescriptor.getType() + "].");
                }
@@ -184,7 +190,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 
                if (namespaceCompatibility.isRequiresMigration() || 
stateCompatibility.isRequiresMigration()) {
                        // TODO state migration currently isn't possible.
-                       throw new StateMigrationException("State migration 
isn't supported, yet.");
+                       throw StateMigrationException.notSupported();
                } else {
                        return new RegisteredKeyValueStateBackendMetaInfo<>(
                                newStateDescriptor.getType(),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 2c6101e..34c9698 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -28,6 +28,7 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.CompatibilityResult;
 import org.apache.flink.api.common.typeutils.CompatibilityUtil;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer;
@@ -79,6 +80,7 @@ import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
 
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.util.ArrayList;
@@ -111,60 +113,15 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                        Tuple2.of(FoldingStateDescriptor.class, (StateFactory) 
HeapFoldingState::create)
                ).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
-       @SuppressWarnings("unchecked")
-       @Nonnull
-       @Override
-       public <T extends HeapPriorityQueueElement & PriorityComparable & 
Keyed> KeyGroupedInternalPriorityQueue<T> create(
-               @Nonnull String stateName,
-               @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
-
-               final StateSnapshotRestore snapshotRestore = 
registeredStates.get(stateName);
-
-               if (snapshotRestore instanceof 
HeapPriorityQueueSnapshotRestoreWrapper) {
-                       //TODO Serializer upgrade story!?
-                       return ((HeapPriorityQueueSnapshotRestoreWrapper<T>) 
snapshotRestore).getPriorityQueue();
-               } else if (snapshotRestore != null) {
-                       throw new IllegalStateException("Already found a 
different state type registered under this name: " + 
snapshotRestore.getClass());
-               }
-
-               final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo =
-                       new 
RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, 
byteOrderedElementSerializer);
-
-               return createInternal(metaInfo);
-       }
-
-       @Nonnull
-       private <T extends HeapPriorityQueueElement & PriorityComparable & 
Keyed> KeyGroupedInternalPriorityQueue<T> createInternal(
-               RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
-
-               final String stateName = metaInfo.getName();
-               final HeapPriorityQueueSet<T> priorityQueue = 
priorityQueueSetFactory.create(
-                       stateName,
-                       metaInfo.getElementSerializer());
-
-               HeapPriorityQueueSnapshotRestoreWrapper<T> wrapper =
-                       new HeapPriorityQueueSnapshotRestoreWrapper<>(
-                               priorityQueue,
-                               metaInfo,
-                               KeyExtractorFunction.forKeyedObjects(),
-                               keyGroupRange,
-                               numberOfKeyGroups);
-
-               registeredStates.put(stateName, wrapper);
-               return priorityQueue;
-       }
-
-       private interface StateFactory {
-               <K, N, SV, S extends State, IS extends S> IS createState(
-                       StateDescriptor<S, SV> stateDesc,
-                       StateTable<K, N, SV> stateTable,
-                       TypeSerializer<K> keySerializer) throws Exception;
-       }
+       /**
+        * Map of registered Key/Value states.
+        */
+       private final Map<String, StateTable<K, ?, ?>> registeredKVStates;
 
        /**
-        * Map of registered states for snapshot/restore.
+        * Map of registered priority queue set states.
         */
-       private final Map<String, StateSnapshotRestore> registeredStates = new 
HashMap<>();
+       private final Map<String, HeapPriorityQueueSnapshotRestoreWrapper> 
registeredPQStates;
 
        /**
         * Map of state names to their corresponding restored state meta info.
@@ -172,7 +129,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
         * <p>TODO this map can be removed when eager-state registration is in 
place.
         * TODO we currently need this cached to check state migration 
strategies when new serializers are registered.
         */
-       private final Map<String, StateMetaInfoSnapshot> 
restoredKvStateMetaInfos;
+       private final Map<StateUID, StateMetaInfoSnapshot> 
restoredStateMetaInfo;
 
        /**
         * The configuration for local recovery.
@@ -203,6 +160,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
                super(kvStateRegistry, keySerializer, userCodeClassLoader,
                        numberOfKeyGroups, keyGroupRange, executionConfig, 
ttlTimeProvider);
+
+               this.registeredKVStates = new HashMap<>();
+               this.registeredPQStates = new HashMap<>();
                this.localRecoveryConfig = 
Preconditions.checkNotNull(localRecoveryConfig);
 
                SnapshotStrategySynchronicityBehavior<K> synchronicityTrait = 
asynchronousSnapshots ?
@@ -211,7 +171,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
                this.snapshotStrategy = new 
HeapSnapshotStrategy(synchronicityTrait);
                LOG.info("Initializing heap keyed state backend with stream 
factory.");
-               this.restoredKvStateMetaInfos = new HashMap<>();
+               this.restoredStateMetaInfo = new HashMap<>();
                this.priorityQueueSetFactory = priorityQueueSetFactory;
        }
 
@@ -219,17 +179,85 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
        //  state backend operations
        // 
------------------------------------------------------------------------
 
+       @SuppressWarnings("unchecked")
+       @Nonnull
+       @Override
+       public <T extends HeapPriorityQueueElement & PriorityComparable & 
Keyed> KeyGroupedInternalPriorityQueue<T> create(
+               @Nonnull String stateName,
+               @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
+
+               final HeapPriorityQueueSnapshotRestoreWrapper existingState = 
registeredPQStates.get(stateName);
+
+               if (existingState != null) {
+                       // TODO we implement the simple way of supporting the 
current functionality, mimicking keyed state
+                       // because this should be reworked in FLINK-9376 and 
then we should have a common algorithm over
+                       // StateMetaInfoSnapshot that avoids this code 
duplication.
+                       StateMetaInfoSnapshot restoredMetaInfoSnapshot =
+                               
restoredStateMetaInfo.get(StateUID.of(stateName, 
StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE));
+
+                       Preconditions.checkState(
+                               restoredMetaInfoSnapshot != null,
+                               "Requested to check compatibility of a restored 
RegisteredKeyedBackendStateMetaInfo," +
+                                       " but its corresponding restored 
snapshot cannot be found.");
+
+                       StateMetaInfoSnapshot.CommonSerializerKeys 
serializerKey =
+                               
StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER;
+
+                       CompatibilityResult<T> compatibilityResult = 
CompatibilityUtil.resolveCompatibilityResult(
+                               
restoredMetaInfoSnapshot.getTypeSerializer(serializerKey),
+                               null,
+                               
restoredMetaInfoSnapshot.getTypeSerializerConfigSnapshot(serializerKey),
+                               byteOrderedElementSerializer);
+
+                       if (compatibilityResult.isRequiresMigration()) {
+                               throw new 
FlinkRuntimeException(StateMigrationException.notSupported());
+                       } else {
+                               registeredPQStates.put(
+                                       stateName,
+                                       
existingState.forUpdatedSerializer(byteOrderedElementSerializer));
+                       }
+
+                       return existingState.getPriorityQueue();
+               } else {
+                       final RegisteredPriorityQueueStateBackendMetaInfo<T> 
metaInfo =
+                               new 
RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, 
byteOrderedElementSerializer);
+                       return createInternal(metaInfo);
+               }
+       }
+
+       @Nonnull
+       private <T extends HeapPriorityQueueElement & PriorityComparable & 
Keyed> KeyGroupedInternalPriorityQueue<T> createInternal(
+               RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
+
+               final String stateName = metaInfo.getName();
+               final HeapPriorityQueueSet<T> priorityQueue = 
priorityQueueSetFactory.create(
+                       stateName,
+                       metaInfo.getElementSerializer());
+
+               HeapPriorityQueueSnapshotRestoreWrapper<T> wrapper =
+                       new HeapPriorityQueueSnapshotRestoreWrapper<>(
+                               priorityQueue,
+                               metaInfo,
+                               KeyExtractorFunction.forKeyedObjects(),
+                               keyGroupRange,
+                               numberOfKeyGroups);
+
+               registeredPQStates.put(stateName, wrapper);
+               return priorityQueue;
+       }
+
        private <N, V> StateTable<K, N, V> tryRegisterStateTable(
                        TypeSerializer<N> namespaceSerializer, 
StateDescriptor<?, V> stateDesc) throws StateMigrationException {
 
                @SuppressWarnings("unchecked")
-               StateTable<K, N, V> stateTable = (StateTable<K, N, V>) 
registeredStates.get(stateDesc.getName());
+               StateTable<K, N, V> stateTable = (StateTable<K, N, V>) 
registeredKVStates.get(stateDesc.getName());
 
                RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo;
                if (stateTable != null) {
                        @SuppressWarnings("unchecked")
                        StateMetaInfoSnapshot restoredMetaInfoSnapshot =
-                               
restoredKvStateMetaInfos.get(stateDesc.getName());
+                               restoredStateMetaInfo.get(
+                                       StateUID.of(stateDesc.getName(), 
StateMetaInfoSnapshot.BackendStateType.KEY_VALUE));
 
                        Preconditions.checkState(
                                restoredMetaInfoSnapshot != null,
@@ -250,7 +278,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                                stateDesc.getSerializer());
 
                        stateTable = 
snapshotStrategy.newStateTable(newMetaInfo);
-                       registeredStates.put(stateDesc.getName(), stateTable);
+                       registeredKVStates.put(stateDesc.getName(), stateTable);
                }
 
                return stateTable;
@@ -259,20 +287,17 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
        @SuppressWarnings("unchecked")
        @Override
        public <N> Stream<K> getKeys(String state, N namespace) {
-               if (!registeredStates.containsKey(state)) {
+               if (!registeredKVStates.containsKey(state)) {
                        return Stream.empty();
                }
 
-               final StateSnapshotRestore stateSnapshotRestore = 
registeredStates.get(state);
-               if (!(stateSnapshotRestore instanceof StateTable)) {
-                       return Stream.empty();
-               }
+               final StateSnapshotRestore stateSnapshotRestore = 
registeredKVStates.get(state);
                StateTable<K, N, ?> table = (StateTable<K, N, ?>) 
stateSnapshotRestore;
                return table.getKeys(namespace);
        }
 
        private boolean hasRegisteredState() {
-               return !registeredStates.isEmpty();
+               return !(registeredKVStates.isEmpty() && 
registeredPQStates.isEmpty());
        }
 
        @Override
@@ -318,9 +343,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
        @SuppressWarnings({"unchecked"})
        private void restorePartitionedState(Collection<KeyedStateHandle> 
state) throws Exception {
 
-               final Map<Integer, String> kvStatesById = new HashMap<>();
-               int numRegisteredKvStates = 0;
-               registeredStates.clear();
+               final Map<Integer, StateMetaInfoSnapshot> kvStatesById = new 
HashMap<>();
+               registeredKVStates.clear();
+               registeredPQStates.clear();
 
                boolean keySerializerRestored = false;
 
@@ -369,70 +394,131 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                                }
 
                                List<StateMetaInfoSnapshot> restoredMetaInfos =
-                                               
serializationProxy.getStateMetaInfoSnapshots();
-
-                               for (StateMetaInfoSnapshot restoredMetaInfo : 
restoredMetaInfos) {
-                                       
restoredKvStateMetaInfos.put(restoredMetaInfo.getName(), restoredMetaInfo);
-
-                                       StateSnapshotRestore snapshotRestore = 
registeredStates.get(restoredMetaInfo.getName());
+                                       
serializationProxy.getStateMetaInfoSnapshots();
 
-                                       //important: only create a new table we 
did not already create it previously
-                                       if (null == snapshotRestore) {
+                               
createOrCheckStateForMetaInfo(restoredMetaInfos, kvStatesById);
 
-                                               if 
(restoredMetaInfo.getBackendStateType() == 
StateMetaInfoSnapshot.BackendStateType.KEY_VALUE) {
-                                                       
RegisteredKeyValueStateBackendMetaInfo<?, ?> 
registeredKeyedBackendStateMetaInfo =
-                                                               new 
RegisteredKeyValueStateBackendMetaInfo<>(restoredMetaInfo);
-
-                                                       snapshotRestore = 
snapshotStrategy.newStateTable(registeredKeyedBackendStateMetaInfo);
-                                                       
registeredStates.put(restoredMetaInfo.getName(), snapshotRestore);
-                                               } else {
-                                                       createInternal(new 
RegisteredPriorityQueueStateBackendMetaInfo<>(restoredMetaInfo));
-                                               }
-                                               
kvStatesById.put(numRegisteredKvStates, restoredMetaInfo.getName());
-                                               ++numRegisteredKvStates;
-                                       } else {
-                                               // TODO with eager state 
registration in place, check here for serializer migration strategies
-                                       }
+                               readStateHandleStateData(
+                                       fsDataInputStream,
+                                       inView,
+                                       
keyGroupsStateHandle.getGroupRangeOffsets(),
+                                       kvStatesById, restoredMetaInfos.size(),
+                                       serializationProxy.getReadVersion(),
+                                       
serializationProxy.isUsingKeyGroupCompression());
+                       } finally {
+                               if 
(cancelStreamRegistry.unregisterCloseable(fsDataInputStream)) {
+                                       IOUtils.closeQuietly(fsDataInputStream);
                                }
+                       }
+               }
+       }
 
-                               final StreamCompressionDecorator 
streamCompressionDecorator = serializationProxy.isUsingKeyGroupCompression() ?
-                                       
SnappyStreamCompressionDecorator.INSTANCE : 
UncompressedStreamCompressionDecorator.INSTANCE;
+       private void readStateHandleStateData(
+               FSDataInputStream fsDataInputStream,
+               DataInputViewStreamWrapper inView,
+               KeyGroupRangeOffsets keyGroupOffsets,
+               Map<Integer, StateMetaInfoSnapshot> kvStatesById,
+               int numStates,
+               int readVersion,
+               boolean isCompressed) throws IOException {
 
-                               for (Tuple2<Integer, Long> groupOffset : 
keyGroupsStateHandle.getGroupRangeOffsets()) {
-                                       int keyGroupIndex = groupOffset.f0;
-                                       long offset = groupOffset.f1;
+               final StreamCompressionDecorator streamCompressionDecorator = 
isCompressed ?
+                       SnappyStreamCompressionDecorator.INSTANCE : 
UncompressedStreamCompressionDecorator.INSTANCE;
 
-                                       // Check that restored key groups all 
belong to the backend.
-                                       
Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group 
must belong to the backend.");
+               for (Tuple2<Integer, Long> groupOffset : keyGroupOffsets) {
+                       int keyGroupIndex = groupOffset.f0;
+                       long offset = groupOffset.f1;
 
-                                       fsDataInputStream.seek(offset);
+                       // Check that restored key groups all belong to the 
backend.
+                       
Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group 
must belong to the backend.");
 
-                                       int writtenKeyGroupIndex = 
inView.readInt();
+                       fsDataInputStream.seek(offset);
 
-                                       try (InputStream kgCompressionInStream =
-                                                       
streamCompressionDecorator.decorateWithCompression(fsDataInputStream)) {
+                       int writtenKeyGroupIndex = inView.readInt();
+                       Preconditions.checkState(writtenKeyGroupIndex == 
keyGroupIndex,
+                               "Unexpected key-group in restore.");
 
-                                               DataInputViewStreamWrapper 
kgCompressionInView =
-                                                       new 
DataInputViewStreamWrapper(kgCompressionInStream);
+                       try (InputStream kgCompressionInStream =
+                                        
streamCompressionDecorator.decorateWithCompression(fsDataInputStream)) {
 
-                                               
Preconditions.checkState(writtenKeyGroupIndex == keyGroupIndex,
-                                                       "Unexpected key-group 
in restore.");
+                               readKeyGroupStateData(
+                                       kgCompressionInStream,
+                                       kvStatesById,
+                                       keyGroupIndex,
+                                       numStates,
+                                       readVersion);
+                       }
+               }
+       }
 
-                                               for (int i = 0; i < 
restoredMetaInfos.size(); i++) {
-                                                       int kvStateId = 
kgCompressionInView.readShort();
-                                                       StateSnapshotRestore 
registeredState = registeredStates.get(kvStatesById.get(kvStateId));
+       private void readKeyGroupStateData(
+               InputStream inputStream,
+               Map<Integer, StateMetaInfoSnapshot> kvStatesById,
+               int keyGroupIndex,
+               int numStates,
+               int readVersion) throws IOException {
+
+               DataInputViewStreamWrapper inView =
+                       new DataInputViewStreamWrapper(inputStream);
+
+               for (int i = 0; i < numStates; i++) {
+
+                       final int kvStateId = inView.readShort();
+                       final StateMetaInfoSnapshot stateMetaInfoSnapshot = 
kvStatesById.get(kvStateId);
+                       final StateSnapshotRestore registeredState;
+
+                       switch (stateMetaInfoSnapshot.getBackendStateType()) {
+                               case KEY_VALUE:
+                                       registeredState = 
registeredKVStates.get(stateMetaInfoSnapshot.getName());
+                                       break;
+                               case PRIORITY_QUEUE:
+                                       registeredState = 
registeredPQStates.get(stateMetaInfoSnapshot.getName());
+                                       break;
+                               default:
+                                       throw new 
IllegalStateException("Unexpected state type: " +
+                                               
stateMetaInfoSnapshot.getBackendStateType() + ".");
+                       }
 
-                                                       
StateSnapshotKeyGroupReader keyGroupReader =
-                                                               
registeredState.keyGroupReader(serializationProxy.getReadVersion());
+                       StateSnapshotKeyGroupReader keyGroupReader = 
registeredState.keyGroupReader(readVersion);
+                       keyGroupReader.readMappingsInKeyGroup(inView, 
keyGroupIndex);
+               }
+       }
 
-                                                       
keyGroupReader.readMappingsInKeyGroup(kgCompressionInView, keyGroupIndex);
-                                               }
+       private void createOrCheckStateForMetaInfo(
+               List<StateMetaInfoSnapshot> restoredMetaInfo,
+               Map<Integer, StateMetaInfoSnapshot> kvStatesById) {
+
+               for (StateMetaInfoSnapshot metaInfoSnapshot : restoredMetaInfo) 
{
+                       restoredStateMetaInfo.put(
+                               StateUID.of(metaInfoSnapshot.getName(), 
metaInfoSnapshot.getBackendStateType()),
+                               metaInfoSnapshot);
+
+                       final StateSnapshotRestore registeredState;
+
+                       switch (metaInfoSnapshot.getBackendStateType()) {
+                               case KEY_VALUE:
+                                       registeredState = 
registeredKVStates.get(metaInfoSnapshot.getName());
+                                       if (registeredState == null) {
+                                               
RegisteredKeyValueStateBackendMetaInfo<?, ?> 
registeredKeyedBackendStateMetaInfo =
+                                                       new 
RegisteredKeyValueStateBackendMetaInfo<>(metaInfoSnapshot);
+                                               registeredKVStates.put(
+                                                       
metaInfoSnapshot.getName(),
+                                                       
snapshotStrategy.newStateTable(registeredKeyedBackendStateMetaInfo));
                                        }
-                               }
-                       } finally {
-                               if 
(cancelStreamRegistry.unregisterCloseable(fsDataInputStream)) {
-                                       IOUtils.closeQuietly(fsDataInputStream);
-                               }
+                                       break;
+                               case PRIORITY_QUEUE:
+                                       registeredState = 
registeredPQStates.get(metaInfoSnapshot.getName());
+                                       if (registeredState == null) {
+                                               createInternal(new 
RegisteredPriorityQueueStateBackendMetaInfo<>(metaInfoSnapshot));
+                                       }
+                                       break;
+                               default:
+                                       throw new 
IllegalStateException("Unexpected state type: " +
+                                               
metaInfoSnapshot.getBackendStateType() + ".");
+                       }
+
+                       if (registeredState == null) {
+                               kvStatesById.put(kvStatesById.size(), 
metaInfoSnapshot);
                        }
                }
        }
@@ -478,12 +564,10 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        @VisibleForTesting
        @SuppressWarnings("unchecked")
        @Override
-       public int numStateEntries() {
+       public int numKeyValueStateEntries() {
                int sum = 0;
-               for (StateSnapshotRestore state : registeredStates.values()) {
-                       if (state instanceof StateTable) {
-                               sum += ((StateTable<?, ?, ?>) state).size();
-                       }
+               // registeredKVStates only holds key/value StateTables (priority queue states are kept in
+               // registeredPQStates), so the tables can be iterated directly without a cast.
+               for (StateTable<?, ?, ?> state : registeredKVStates.values()) {
+                       sum += state.size();
                }
                return sum;
        }
@@ -492,12 +576,10 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
         * Returns the total number of state entries across all keys for the 
given namespace.
         */
        @VisibleForTesting
-       public int numStateEntries(Object namespace) {
+       public int numKeyValueStateEntries(Object namespace) {
                int sum = 0;
-               for (StateSnapshotRestore state : registeredStates.values()) {
-                       if (state instanceof StateTable) {
-                               sum += ((StateTable<?, ?, ?>) 
state).sizeOfNamespace(namespace);
-                       }
+               // Only key/value state tables are counted here; priority queue states
+               // (registeredPQStates) are not included in this sum.
+               for (StateTable<?, ?, ?> state : registeredKVStates.values()) {
+                       sum += state.sizeOfNamespace(namespace);
                }
                return sum;
        }
@@ -574,7 +656,7 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
                private final SnapshotStrategySynchronicityBehavior<K> 
snapshotStrategySynchronicityTrait;
 
-               public HeapSnapshotStrategy(
+               /**
+                * Creates the heap snapshot strategy with the given synchronicity behavior, to which
+                * {@code newStateTable} delegates the creation of new state tables.
+                */
+               HeapSnapshotStrategy(
                        SnapshotStrategySynchronicityBehavior<K> 
snapshotStrategySynchronicityTrait) {
                        this.snapshotStrategySynchronicityTrait = 
snapshotStrategySynchronicityTrait;
                }
@@ -592,28 +674,31 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
                        long syncStartTime = System.currentTimeMillis();
 
-                       Preconditions.checkState(registeredStates.size() <= 
Short.MAX_VALUE,
-                               "Too many KV-States: " + 
registeredStates.size() +
-                                       ". Currently at most " + 
Short.MAX_VALUE + " states are supported");
-
-                       List<StateMetaInfoSnapshot> metaInfoSnapshots =
-                               new ArrayList<>(registeredStates.size());
+                       int numStates = registeredKVStates.size() + 
registeredPQStates.size();
 
-                       final Map<String, Integer> kVStateToId = new 
HashMap<>(registeredStates.size());
+                       Preconditions.checkState(numStates <= Short.MAX_VALUE,
+                               "Too many states: " + numStates +
+                                       ". Currently at most " + 
Short.MAX_VALUE + " states are supported");
 
-                       final Map<String, StateSnapshot> 
cowStateStableSnapshots =
-                               new HashMap<>(registeredStates.size());
-
-                       for (Map.Entry<String, StateSnapshotRestore> kvState : 
registeredStates.entrySet()) {
-                               String stateName = kvState.getKey();
-                               kVStateToId.put(stateName, kVStateToId.size());
-                               StateSnapshotRestore state = kvState.getValue();
-                               if (null != state) {
-                                       final StateSnapshot stateSnapshot = 
state.stateSnapshot();
-                                       
metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
-                                       cowStateStableSnapshots.put(stateName, 
stateSnapshot);
-                               }
-                       }
+                       final List<StateMetaInfoSnapshot> metaInfoSnapshots = 
new ArrayList<>(numStates);
+                       final Map<StateUID, Integer> stateNamesToId =
+                               new HashMap<>(numStates);
+                       final Map<StateUID, StateSnapshot> 
cowStateStableSnapshots =
+                               new HashMap<>(numStates);
+
+                       processSnapshotMetaInfoForAllStates(
+                               metaInfoSnapshots,
+                               cowStateStableSnapshots,
+                               stateNamesToId,
+                               registeredKVStates,
+                               
StateMetaInfoSnapshot.BackendStateType.KEY_VALUE);
+
+                       processSnapshotMetaInfoForAllStates(
+                               metaInfoSnapshots,
+                               cowStateStableSnapshots,
+                               stateNamesToId,
+                               registeredPQStates,
+                               
StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE);
 
                        final KeyedBackendSerializationProxy<K> 
serializationProxy =
                                new KeyedBackendSerializationProxy<>(
@@ -692,13 +777,14 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                                                        
keyGroupRangeOffsets[keyGroupPos] = localStream.getPos();
                                                        
outView.writeInt(keyGroupId);
 
-                                                       for (Map.Entry<String, 
StateSnapshot> kvState : cowStateStableSnapshots.entrySet()) {
+                                                       for 
(Map.Entry<StateUID, StateSnapshot> stateSnapshot :
+                                                               
cowStateStableSnapshots.entrySet()) {
                                                                
StateSnapshot.StateKeyGroupWriter partitionedSnapshot =
-                                                                       
kvState.getValue().getKeyGroupWriter();
+
+                                                                       
stateSnapshot.getValue().getKeyGroupWriter();
                                                                try 
(OutputStream kgCompressionOut = 
keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
-                                                                       String 
stateName = kvState.getKey();
                                                                        
DataOutputViewStreamWrapper kgCompressionView = new 
DataOutputViewStreamWrapper(kgCompressionOut);
-                                                                       
kgCompressionView.writeShort(kVStateToId.get(stateName));
+                                                                       
kgCompressionView.writeShort(stateNamesToId.get(stateSnapshot.getKey()));
                                                                        
partitionedSnapshot.writeStateInKeyGroup(kgCompressionView, keyGroupId);
                                                                } // this will 
just close the outer compression stream
                                                        }
@@ -747,5 +833,80 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                public <N, V> StateTable<K, N, V> 
newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
                        return 
snapshotStrategySynchronicityTrait.newStateTable(newMetaInfo);
                }
+
+               /**
+                * Registers every state of {@code registeredStates} for snapshotting under a {@link StateUID}
+                * built from the state name and the given {@code stateType}.
+                *
+                * <p>Every state gets the next id in {@code stateUidToId} (the map's current size), even when
+                * its registered value is {@code null}. For non-null states a {@link StateSnapshot} is taken;
+                * the snapshot is stored in {@code cowStateStableSnapshots} and its meta info snapshot is
+                * appended to {@code metaInfoSnapshots}.
+                */
+               private void processSnapshotMetaInfoForAllStates(
+                       List<StateMetaInfoSnapshot> metaInfoSnapshots,
+                       Map<StateUID, StateSnapshot> cowStateStableSnapshots,
+                       Map<StateUID, Integer> stateUidToId,
+                       Map<String, ? extends StateSnapshotRestore> registeredStates,
+                       StateMetaInfoSnapshot.BackendStateType stateType) {
+
+                       for (Map.Entry<String, ? extends StateSnapshotRestore> entry : registeredStates.entrySet()) {
+                               final StateUID uid = StateUID.of(entry.getKey(), stateType);
+                               // ids are handed out in iteration order, continuing across calls that share the map
+                               stateUidToId.put(uid, stateUidToId.size());
+
+                               final StateSnapshotRestore registered = entry.getValue();
+                               if (registered == null) {
+                                       continue;
+                               }
+
+                               final StateSnapshot snapshot = registered.stateSnapshot();
+                               metaInfoSnapshots.add(snapshot.getMetaInfoSnapshot());
+                               cowStateStableSnapshots.put(uid, snapshot);
+                       }
+               }
+       }
+
+       /**
+        * Creates a state object of type {@code IS} (a subtype of the descriptor's state type
+        * {@code S}) from a state descriptor, the {@link StateTable} for that state, and the key
+        * serializer. Implementations may throw any exception.
+        */
+       private interface StateFactory {
+               <K, N, SV, S extends State, IS extends S> IS createState(
+                       StateDescriptor<S, SV> stateDesc,
+                       StateTable<K, N, SV> stateTable,
+                       TypeSerializer<K> keySerializer) throws Exception;
+       }
+
+       /**
+        * Unique identifier for registered state in this backend, made of the state name and the
+        * backend state type. States with the same name but different types (e.g. a key/value state
+        * and a priority queue state) therefore have different identifiers.
+        */
+       private static final class StateUID {
+
+               @Nonnull
+               private final String stateName;
+
+               @Nonnull
+               private final StateMetaInfoSnapshot.BackendStateType stateType;
+
+               StateUID(@Nonnull String stateName, @Nonnull StateMetaInfoSnapshot.BackendStateType stateType) {
+                       this.stateName = stateName;
+                       this.stateType = stateType;
+               }
+
+               @Nonnull
+               public String getStateName() {
+                       return stateName;
+               }
+
+               @Nonnull
+               public StateMetaInfoSnapshot.BackendStateType getStateType() {
+                       return stateType;
+               }
+
+               @Override
+               public boolean equals(Object o) {
+                       if (this == o) {
+                               return true;
+                       }
+                       if (o == null || getClass() != o.getClass()) {
+                               return false;
+                       }
+                       StateUID uid = (StateUID) o;
+                       return Objects.equals(getStateName(), uid.getStateName()) &&
+                               getStateType() == uid.getStateType();
+               }
+
+               @Override
+               public int hashCode() {
+                       return Objects.hash(getStateName(), getStateType());
+               }
+
+               // Overridden together with equals/hashCode so ids show up readably in logs and debuggers.
+               @Override
+               public String toString() {
+                       return "StateUID{stateName='" + stateName + "', stateType=" + stateType + '}';
+               }
+
+               /** Static factory, equivalent to calling the constructor. */
+               public static StateUID of(@Nonnull String stateName, @Nonnull StateMetaInfoSnapshot.BackendStateType stateType) {
+                       return new StateUID(stateName, stateType);
+               }
        }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
index b0255d3..80d79ac 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
@@ -59,7 +59,7 @@ public class HeapPriorityQueueSetFactory implements 
PriorityQueueSetFactory {
                @Nonnull String stateName,
                @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
 
-               return new HeapPriorityQueueSet<T>(
+               return new HeapPriorityQueueSet<>(
                        PriorityComparator.forPriorityComparableObjects(),
                        KeyExtractorFunction.forKeyedObjects(),
                        minimumCapacity,
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
index b2b2843..fc1e0db 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
@@ -89,4 +89,26 @@ public class HeapPriorityQueueSnapshotRestoreWrapper<T 
extends HeapPriorityQueue
        public HeapPriorityQueueSet<T> getPriorityQueue() {
                return priorityQueue;
        }
+
+       /**
+        * Returns the registered meta info of this priority queue state.
+        */
+       @Nonnull
+       public RegisteredPriorityQueueStateBackendMetaInfo<T> getMetaInfo() {
+               return this.metaInfo;
+       }
+
+       /**
+        * Returns a new wrapper whose meta info carries the given serializer under the same state name.
+        *
+        * <p>This is NOT a deep copy: the returned wrapper shares this wrapper's priority queue, key
+        * extractor function, local key-group range and total number of key groups. Only the meta info
+        * object is newly created.
+        *
+        * @param updatedSerializer the element serializer to put into the new meta info.
+        * @return a new wrapper over the same priority queue, using the updated meta info.
+        */
+       public HeapPriorityQueueSnapshotRestoreWrapper<T> forUpdatedSerializer(
+               @Nonnull TypeSerializer<T> updatedSerializer) {
+
+               RegisteredPriorityQueueStateBackendMetaInfo<T> updatedMetaInfo =
+                       new RegisteredPriorityQueueStateBackendMetaInfo<>(metaInfo.getName(), updatedSerializer);
+
+               return new HeapPriorityQueueSnapshotRestoreWrapper<>(
+                       priorityQueue,
+                       updatedMetaInfo,
+                       keyExtractorFunction,
+                       localKeyGroupRange,
+                       totalKeyGroups);
+       }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
index 510d277..935ebb6 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
@@ -347,7 +347,7 @@ public abstract class InternalPriorityQueueTestBase extends 
TestLogger {
        /**
         * Payload for usage in the test.
         */
-       protected static class TestElement implements HeapPriorityQueueElement {
+       protected static class TestElement implements HeapPriorityQueueElement, 
Keyed<Long>, PriorityComparable<TestElement> {
 
                private final long key;
                private final long priority;
@@ -359,7 +359,12 @@ public abstract class InternalPriorityQueueTestBase 
extends TestLogger {
                        this.internalIndex = NOT_CONTAINED;
                }
 
-               public long getKey() {
+               /**
+                * Orders elements solely by their {@code priority} field; the key does not take part.
+                */
+               @Override
+               public int comparePriorityTo(@Nonnull TestElement other) {
+                       final long otherPriority = other.priority;
+                       return Long.compare(this.priority, otherPriority);
+               }
+
+               /**
+                * Returns this element's key, boxed to satisfy {@code Keyed<Long>}.
+                */
+               @Override
+               public Long getKey() {
                        return key;
                }
 
@@ -386,8 +391,8 @@ public abstract class InternalPriorityQueueTestBase extends 
TestLogger {
                                return false;
                        }
                        TestElement that = (TestElement) o;
-                       return getKey() == that.getKey() &&
-                               getPriority() == that.getPriority();
+                       return key == that.key &&
+                               priority == that.priority;
                }
 
                @Override
@@ -414,9 +419,11 @@ public abstract class InternalPriorityQueueTestBase 
extends TestLogger {
         */
        protected static class TestElementSerializer extends 
TypeSerializer<TestElement> {
 
+               private static final int REVISION = 1;
+
                public static final TestElementSerializer INSTANCE = new 
TestElementSerializer();
 
-               private TestElementSerializer() {
+               /**
+                * Protected (previously private) so that subclasses can be created; presumably so tests can
+                * override {@link #getRevision()} to simulate an upgraded serializer -- TODO confirm.
+                */
+               protected TestElementSerializer() {
                }
 
                @Override
@@ -489,14 +496,62 @@ public abstract class InternalPriorityQueueTestBase 
extends TestLogger {
                        return 4711;
                }
 
+               /**
+                * Returns the revision written into this serializer's config snapshot. On restore, snapshots
+                * with a revision up to this value are reported as compatible. Subclasses may override it.
+                */
+               protected int getRevision() {
+                       return TestElementSerializer.REVISION;
+               }
+
                @Override
                public TypeSerializerConfigSnapshot snapshotConfiguration() {
-                       throw new UnsupportedOperationException();
+                       // Persist the current (possibly overridden) revision; ensureCompatibility compares against it.
+                       final Snapshot configSnapshot = new Snapshot(getRevision());
+                       return configSnapshot;
                }
 
                @Override
                public CompatibilityResult<TestElement> ensureCompatibility(TypeSerializerConfigSnapshot configSnapshot) {
-                       throw new UnsupportedOperationException();
+                       if (!(configSnapshot instanceof Snapshot)) {
+                               // Not a config snapshot of this serializer: report that migration is required.
+                               return CompatibilityResult.requiresMigration();
+                       }
+                       // Snapshots from this revision or an older one are compatible; newer ones require migration.
+                       final int restoredRevision = ((Snapshot) configSnapshot).revision;
+                       return restoredRevision <= getRevision()
+                               ? CompatibilityResult.compatible()
+                               : CompatibilityResult.requiresMigration();
+               }
+
+               /**
+                * Config snapshot of {@link TestElementSerializer} that persists the serializer's revision
+                * after the fields written by the superclass.
+                */
+               public static class Snapshot extends TypeSerializerConfigSnapshot {
+
+                       /** Revision of the writing serializer; set by the constructor or by {@link #read}. */
+                       private int revision;
+
+                       /** Nullary constructor; the revision stays 0 until {@link #read} fills it in. */
+                       public Snapshot() {
+                       }
+
+                       public Snapshot(int revision) {
+                               this.revision = revision;
+                       }
+
+                       public int getRevision() {
+                               return revision;
+                       }
+
+                       @Override
+                       public int getVersion() {
+                               return 0;
+                       }
+
+                       @Override
+                       public void write(DataOutputView out) throws IOException {
+                               super.write(out);
+                               out.writeInt(revision);
+                       }
+
+                       @Override
+                       public void read(DataInputView in) throws IOException {
+                               super.read(in);
+                               this.revision = in.readInt();
+                       }
+
+                       @Override
+                       public boolean equals(Object other) {
+                               if (!(other instanceof Snapshot)) {
+                                       return false;
+                               }
+                               return revision == ((Snapshot) other).revision;
+                       }
+
+                       @Override
+                       public int hashCode() {
+                               return revision;
+                       }
                }
        }
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index 0ba4c33..215d7d3 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -33,6 +33,7 @@ import 
org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.testutils.ArtificialCNFExceptionThrowingClassLoader;
 import org.apache.flink.util.FutureUtil;
+
 import org.junit.Assert;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -154,6 +155,7 @@ public class MemoryStateBackendTest extends 
StateBackendTestBase<MemoryStateBack
        @Test
        public void 
testKeyedStateRestoreFailsIfSerializerDeserializationFails() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ValueStateDescriptor<String> kvId = new 
ValueStateDescriptor<>("id", String.class, null);
@@ -161,7 +163,7 @@ public class MemoryStateBackendTest extends 
StateBackendTestBase<MemoryStateBack
 
                HeapKeyedStateBackend<Integer> heapBackend = 
(HeapKeyedStateBackend<Integer>) backend;
 
-               assertEquals(0, heapBackend.numStateEntries());
+               assertEquals(0, heapBackend.numKeyValueStateEntries());
 
                ValueState<String> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
 
@@ -170,11 +172,13 @@ public class MemoryStateBackendTest extends 
StateBackendTestBase<MemoryStateBack
                state.update("hello");
                state.update("ciao");
 
-               KeyedStateHandle snapshot = 
runSnapshot(((HeapKeyedStateBackend<Integer>) backend).snapshot(
-                       682375462378L,
-                       2,
-                       streamFactory,
-                       CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot = runSnapshot(
+                       ((HeapKeyedStateBackend<Integer>) backend).snapshot(
+                               682375462378L,
+                               2,
+                               streamFactory,
+                               
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                backend.dispose();
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index bfdc05d..059a706 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -50,7 +50,10 @@ import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.api.java.typeutils.runtime.PojoSerializer;
 import org.apache.flink.api.java.typeutils.runtime.kryo.JavaSerializer;
 import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.core.testutils.CheckedThread;
@@ -77,6 +80,7 @@ import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory;
 import org.apache.flink.testutils.ArtificialCNFExceptionThrowingClassLoader;
 import org.apache.flink.types.IntValue;
+import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.StateMigrationException;
 import org.apache.flink.util.TestLogger;
@@ -87,6 +91,7 @@ import com.esotericsoftware.kryo.Kryo;
 import com.esotericsoftware.kryo.io.Input;
 import com.esotericsoftware.kryo.io.Output;
 import org.apache.commons.io.output.ByteArrayOutputStream;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -295,6 +300,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        public void testBackendUsesRegisteredKryoDefaultSerializer() throws 
Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
                Environment env = new DummyEnvironment();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
 
                // cast because our test serializer is not typed to TestPojo
@@ -330,7 +336,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                try {
                        // backends that lazily serializes (such as memory 
state backend) will fail here
-                       runSnapshot(backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
                } catch (ExpectedKryoTestException e) {
                        numExceptions++;
                } catch (Exception e) {
@@ -350,6 +358,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        public void 
testBackendUsesRegisteredKryoDefaultSerializerUsingGetOrCreate() throws 
Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
                Environment env = new DummyEnvironment();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
 
                // cast because our test serializer is not typed to TestPojo
@@ -390,7 +399,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                try {
                        // backends that lazily serializes (such as memory 
state backend) will fail here
-                       runSnapshot(backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
                } catch (ExpectedKryoTestException e) {
                        numExceptions++;
                } catch (Exception e) {
@@ -409,8 +420,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        public void testBackendUsesRegisteredKryoSerializer() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
                Environment env = new DummyEnvironment();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
-
                env.getExecutionConfig()
                                .registerTypeWithKryoSerializer(TestPojo.class, 
ExceptionThrowingTestSerializer.class);
 
@@ -444,7 +455,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                try {
                        // backends that lazily serializes (such as memory 
state backend) will fail here
-                       runSnapshot(backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
                } catch (ExpectedKryoTestException e) {
                        numExceptions++;
                } catch (Exception e) {
@@ -464,6 +477,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        public void testBackendUsesRegisteredKryoSerializerUsingGetOrCreate() 
throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
                Environment env = new DummyEnvironment();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
 
                
env.getExecutionConfig().registerTypeWithKryoSerializer(TestPojo.class, 
ExceptionThrowingTestSerializer.class);
@@ -500,7 +514,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                try {
                        // backends that lazily serializes (such as memory 
state backend) will fail here
-                       runSnapshot(backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
                } catch (ExpectedKryoTestException e) {
                        numExceptions++;
                } catch (Exception e) {
@@ -528,6 +544,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        public void testKryoRegisteringRestoreResilienceWithRegisteredType() 
throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
                Environment env = new DummyEnvironment();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
 
                TypeInformation<TestPojo> pojoType = new 
GenericTypeInfo<>(TestPojo.class);
@@ -548,11 +565,13 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                backend.setCurrentKey(2);
                state.update(new TestPojo("u2", 2));
 
-               KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
+               KeyedStateHandle snapshot = runSnapshot(
+                       backend.snapshot(
                                682375462378L,
                                2,
                                streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                               
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                backend.dispose();
 
@@ -617,9 +636,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                                682375462378L,
                                2,
                                streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                               
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-                       snapshot.registerSharedStates(sharedStateRegistry);
                        backend.dispose();
 
                        // ========== restore snapshot - should use default 
serializer (ONLY SERIALIZATION) ==========
@@ -639,13 +658,14 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        // update to test state backends that eagerly 
serialize, such as RocksDB
                        state.update(new TestPojo("u1", 11));
 
-                       KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(
-                               682375462378L,
-                               2,
-                               streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot2 = runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-                       snapshot2.registerSharedStates(sharedStateRegistry);
                        snapshot.discardState();
 
                        backend.dispose();
@@ -715,13 +735,14 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        backend.setCurrentKey(2);
                        state.update(new TestPojo("u2", 2));
 
-                       KeyedStateHandle snapshot = 
runSnapshot(backend.snapshot(
-                               682375462378L,
-                               2,
-                               streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot = runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-                       snapshot.registerSharedStates(sharedStateRegistry);
                        backend.dispose();
 
                        // ========== restore snapshot - should use specific 
serializer (ONLY SERIALIZATION) ==========
@@ -740,13 +761,13 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        // update to test state backends that eagerly 
serialize, such as RocksDB
                        state.update(new TestPojo("u1", 11));
 
-                       KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(
-                               682375462378L,
-                               2,
-                               streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
-
-                       snapshot2.registerSharedStates(sharedStateRegistry);
+                       KeyedStateHandle snapshot2 = runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
                        snapshot.discardState();
 
@@ -783,6 +804,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        public void testKryoRestoreResilienceWithDifferentRegistrationOrder() 
throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
                Environment env = new DummyEnvironment();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
 
                // register A first then B
                
env.getExecutionConfig().registerKryoType(TestNestedPojoClassA.class);
@@ -790,83 +812,91 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
 
-               TypeInformation<TestPojo> pojoType = new 
GenericTypeInfo<>(TestPojo.class);
+               try {
 
-               // make sure that we are in fact using the KryoSerializer
-               assertTrue(pojoType.createSerializer(env.getExecutionConfig()) 
instanceof KryoSerializer);
+                       TypeInformation<TestPojo> pojoType = new 
GenericTypeInfo<>(TestPojo.class);
 
-               ValueStateDescriptor<TestPojo> kvId = new 
ValueStateDescriptor<>("id", pojoType);
-               ValueState<TestPojo> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
+                       // make sure that we are in fact using the 
KryoSerializer
+                       
assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof 
KryoSerializer);
 
-               // access the internal state representation to retrieve the 
original Kryo registration ids;
-               // these will be later used to check that on restore, the new 
Kryo serializer has reconfigured itself to
-               // have identical mappings
-               InternalKvState internalKvState = (InternalKvState) state;
-               KryoSerializer<TestPojo> kryoSerializer = 
(KryoSerializer<TestPojo>) internalKvState.getValueSerializer();
-               int mainPojoClassRegistrationId = 
kryoSerializer.getKryo().getRegistration(TestPojo.class).getId();
-               int nestedPojoClassARegistrationId = 
kryoSerializer.getKryo().getRegistration(TestNestedPojoClassA.class).getId();
-               int nestedPojoClassBRegistrationId = 
kryoSerializer.getKryo().getRegistration(TestNestedPojoClassB.class).getId();
+                       ValueStateDescriptor<TestPojo> kvId = new 
ValueStateDescriptor<>("id", pojoType);
+                       ValueState<TestPojo> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
 
-               // ============== create snapshot of current configuration 
==============
+                       // access the internal state representation to retrieve 
the original Kryo registration ids;
+                       // these will be later used to check that on restore, 
the new Kryo serializer has reconfigured itself to
+                       // have identical mappings
+                       InternalKvState internalKvState = (InternalKvState) 
state;
+                       KryoSerializer<TestPojo> kryoSerializer = 
(KryoSerializer<TestPojo>) internalKvState.getValueSerializer();
+                       int mainPojoClassRegistrationId = 
kryoSerializer.getKryo().getRegistration(TestPojo.class).getId();
+                       int nestedPojoClassARegistrationId = 
kryoSerializer.getKryo().getRegistration(TestNestedPojoClassA.class).getId();
+                       int nestedPojoClassBRegistrationId = 
kryoSerializer.getKryo().getRegistration(TestNestedPojoClassB.class).getId();
 
-               // make some more modifications
-               backend.setCurrentKey(1);
-               state.update(new TestPojo("u1", 1, new 
TestNestedPojoClassA(1.0, 2), new TestNestedPojoClassB(2.3, "foo")));
+                       // ============== create snapshot of current 
configuration ==============
 
-               backend.setCurrentKey(2);
-               state.update(new TestPojo("u2", 2, new 
TestNestedPojoClassA(2.0, 5), new TestNestedPojoClassB(3.1, "bar")));
+                       // make some more modifications
+                       backend.setCurrentKey(1);
+                       state.update(new TestPojo("u1", 1, new 
TestNestedPojoClassA(1.0, 2), new TestNestedPojoClassB(2.3, "foo")));
 
-               KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
-                       682375462378L,
-                       2,
-                       streamFactory,
-                       CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       backend.setCurrentKey(2);
+                       state.update(new TestPojo("u2", 2, new 
TestNestedPojoClassA(2.0, 5), new TestNestedPojoClassB(3.1, "bar")));
 
-               backend.dispose();
+                       KeyedStateHandle snapshot = runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-               // ========== restore snapshot, with a different registration 
order in the configuration ==========
+                       backend.dispose();
 
-               env = new DummyEnvironment();
+                       // ========== restore snapshot, with a different 
registration order in the configuration ==========
 
-               
env.getExecutionConfig().registerKryoType(TestNestedPojoClassB.class); // this 
time register B first
-               
env.getExecutionConfig().registerKryoType(TestNestedPojoClassA.class);
+                       env = new DummyEnvironment();
 
-               backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, 
env);
+                       
env.getExecutionConfig().registerKryoType(TestNestedPojoClassB.class); // this 
time register B first
+                       
env.getExecutionConfig().registerKryoType(TestNestedPojoClassA.class);
 
-               snapshot.discardState();
+                       backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot, env);
 
-               // re-initialize to ensure that we create the KryoSerializer 
from scratch, otherwise
-               // initializeSerializerUnlessSet would not pick up our new 
config
-               kvId = new ValueStateDescriptor<>("id", pojoType);
-               state = backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
+                       // re-initialize to ensure that we create the 
KryoSerializer from scratch, otherwise
+                       // initializeSerializerUnlessSet would not pick up our 
new config
+                       kvId = new ValueStateDescriptor<>("id", pojoType);
+                       state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
 
-               // verify that on restore, the serializer that the state handle 
uses has reconfigured itself to have
-               // identical Kryo registration ids compared to the previous 
execution
-               internalKvState = (InternalKvState) state;
-               kryoSerializer = (KryoSerializer<TestPojo>) 
internalKvState.getValueSerializer();
-               assertEquals(mainPojoClassRegistrationId, 
kryoSerializer.getKryo().getRegistration(TestPojo.class).getId());
-               assertEquals(nestedPojoClassARegistrationId, 
kryoSerializer.getKryo().getRegistration(TestNestedPojoClassA.class).getId());
-               assertEquals(nestedPojoClassBRegistrationId, 
kryoSerializer.getKryo().getRegistration(TestNestedPojoClassB.class).getId());
+                       // verify that on restore, the serializer that the 
state handle uses has reconfigured itself to have
+                       // identical Kryo registration ids compared to the 
previous execution
+                       internalKvState = (InternalKvState) state;
+                       kryoSerializer = (KryoSerializer<TestPojo>) 
internalKvState.getValueSerializer();
+                       assertEquals(mainPojoClassRegistrationId, 
kryoSerializer.getKryo().getRegistration(TestPojo.class).getId());
+                       assertEquals(nestedPojoClassARegistrationId, 
kryoSerializer.getKryo().getRegistration(TestNestedPojoClassA.class).getId());
+                       assertEquals(nestedPojoClassBRegistrationId, 
kryoSerializer.getKryo().getRegistration(TestNestedPojoClassB.class).getId());
 
-               backend.setCurrentKey(1);
+                       backend.setCurrentKey(1);
 
-               // update to test state backends that eagerly serialize, such 
as RocksDB
-               state.update(new TestPojo("u1", 11, new 
TestNestedPojoClassA(22.1, 12), new TestNestedPojoClassB(1.23, "foobar")));
+                       // update to test state backends that eagerly 
serialize, such as RocksDB
+                       state.update(new TestPojo("u1", 11, new 
TestNestedPojoClassA(22.1, 12), new TestNestedPojoClassB(1.23, "foobar")));
 
-               // this tests backends that lazily serialize, such as memory 
state backend
-               runSnapshot(backend.snapshot(
-                       682375462378L,
-                       2,
-                       streamFactory,
-                       CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       // this tests backends that lazily serialize, such as 
memory state backend
+                       runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-               backend.dispose();
+                       snapshot.discardState();
+               } finally {
+                       backend.dispose();
+               }
        }
 
        @Test
        public void testPojoRestoreResilienceWithDifferentRegistrationOrder() 
throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
                Environment env = new DummyEnvironment();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
 
                // register A first then B
                
env.getExecutionConfig().registerPojoType(TestNestedPojoClassA.class);
@@ -874,60 +904,66 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
 
-               TypeInformation<TestPojo> pojoType = 
TypeExtractor.getForClass(TestPojo.class);
+               try {
 
-               // make sure that we are in fact using the PojoSerializer
-               assertTrue(pojoType.createSerializer(env.getExecutionConfig()) 
instanceof PojoSerializer);
+                       TypeInformation<TestPojo> pojoType = 
TypeExtractor.getForClass(TestPojo.class);
 
-               ValueStateDescriptor<TestPojo> kvId = new 
ValueStateDescriptor<>("id", pojoType);
-               ValueState<TestPojo> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
+                       // make sure that we are in fact using the 
PojoSerializer
+                       
assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof 
PojoSerializer);
 
-               // ============== create snapshot of current configuration 
==============
+                       ValueStateDescriptor<TestPojo> kvId = new 
ValueStateDescriptor<>("id", pojoType);
+                       ValueState<TestPojo> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
 
-               // make some more modifications
-               backend.setCurrentKey(1);
-               state.update(new TestPojo("u1", 1, new 
TestNestedPojoClassA(1.0, 2), new TestNestedPojoClassB(2.3, "foo")));
+                       // ============== create snapshot of current 
configuration ==============
 
-               backend.setCurrentKey(2);
-               state.update(new TestPojo("u2", 2, new 
TestNestedPojoClassA(2.0, 5), new TestNestedPojoClassB(3.1, "bar")));
+                       // make some more modifications
+                       backend.setCurrentKey(1);
+                       state.update(new TestPojo("u1", 1, new 
TestNestedPojoClassA(1.0, 2), new TestNestedPojoClassB(2.3, "foo")));
 
-               KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
-                       682375462378L,
-                       2,
-                       streamFactory,
-                       CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       backend.setCurrentKey(2);
+                       state.update(new TestPojo("u2", 2, new 
TestNestedPojoClassA(2.0, 5), new TestNestedPojoClassB(3.1, "bar")));
 
-               backend.dispose();
+                       KeyedStateHandle snapshot = runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-               // ========== restore snapshot, with a different registration 
order in the configuration ==========
+                       backend.dispose();
 
-               env = new DummyEnvironment();
+                       // ========== restore snapshot, with a different 
registration order in the configuration ==========
 
-               
env.getExecutionConfig().registerPojoType(TestNestedPojoClassB.class); // this 
time register B first
-               
env.getExecutionConfig().registerPojoType(TestNestedPojoClassA.class);
+                       env = new DummyEnvironment();
 
-               backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, 
env);
+                       
env.getExecutionConfig().registerPojoType(TestNestedPojoClassB.class); // this 
time register B first
+                       
env.getExecutionConfig().registerPojoType(TestNestedPojoClassA.class);
 
-               snapshot.discardState();
+                       backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot, env);
 
-               // re-initialize to ensure that we create the PojoSerializer 
from scratch, otherwise
-               // initializeSerializerUnlessSet would not pick up our new 
config
-               kvId = new ValueStateDescriptor<>("id", pojoType);
-               state = backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
+                       // re-initialize to ensure that we create the 
PojoSerializer from scratch, otherwise
+                       // initializeSerializerUnlessSet would not pick up our 
new config
+                       kvId = new ValueStateDescriptor<>("id", pojoType);
+                       state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
 
-               backend.setCurrentKey(1);
+                       backend.setCurrentKey(1);
 
-               // update to test state backends that eagerly serialize, such 
as RocksDB
-               state.update(new TestPojo("u1", 11, new 
TestNestedPojoClassA(22.1, 12), new TestNestedPojoClassB(1.23, "foobar")));
+                       // update to test state backends that eagerly 
serialize, such as RocksDB
+                       state.update(new TestPojo("u1", 11, new 
TestNestedPojoClassA(22.1, 12), new TestNestedPojoClassB(1.23, "foobar")));
 
-               // this tests backends that lazily serialize, such as memory 
state backend
-               runSnapshot(backend.snapshot(
-                       682375462378L,
-                       2,
-                       streamFactory,
-                       CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       // this tests backends that lazily serialize, such as 
memory state backend
+                       runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()), sharedStateRegistry);
 
-               backend.dispose();
+                       snapshot.discardState();
+               } finally {
+                       backend.dispose();
+               }
        }
 
        @Test
@@ -957,13 +993,14 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        assertTrue(internal.getValueSerializer() instanceof 
TestReconfigurableCustomTypeSerializer);
                        assertFalse(((TestReconfigurableCustomTypeSerializer) 
internal.getValueSerializer()).isReconfigured());
 
-                       KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(
-                               682375462378L,
-                               2,
-                               streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot1 = runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-                       snapshot1.registerSharedStates(sharedStateRegistry);
                        backend.dispose();
 
                        // ========== restore snapshot, which should 
reconfigure the serializer, and then create a snapshot again ==========
@@ -995,13 +1032,14 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                        state.update(new 
TestCustomStateClass("new-test-message-2", "extra-message-2"));
 
-                       KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(
-                               682375462379L,
-                               3,
-                               streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot2 = runSnapshot(
+                               backend.snapshot(
+                                       682375462379L,
+                                       3,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-                       snapshot2.registerSharedStates(sharedStateRegistry);
                        snapshot1.discardState();
                        backend.dispose();
 
@@ -1055,13 +1093,14 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        backend.setCurrentKey(2);
                        state.update(new TestCustomStateClass("test-message-2", 
"this-should-be-ignored"));
 
-                       KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(
-                               682375462378L,
-                               2,
-                               streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot1 = runSnapshot(
+                               backend.snapshot(
+                                       682375462378L,
+                                       2,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-                       snapshot1.registerSharedStates(sharedStateRegistry);
                        backend.dispose();
 
                        // ========== restore snapshot, using the new 
serializer (that has different classname) ==========
@@ -1093,13 +1132,14 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        backend.setCurrentKey(2);
                        state.update(new 
TestCustomStateClass("new-test-message-2", "extra-message-2"));
 
-                       KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(
-                               682375462379L,
-                               3,
-                               streamFactory,
-                               
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot2 = runSnapshot(
+                               backend.snapshot(
+                                       682375462379L,
+                                       3,
+                                       streamFactory,
+                                       
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
-                       snapshot2.registerSharedStates(sharedStateRegistry);
                        snapshot1.discardState();
                } finally {
                        backend.dispose();
@@ -1107,9 +1147,105 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        }
 
        @Test
+       public void testPriorityQueueSerializerUpdates() throws Exception {
+               // Checks priority queue state against serializer upgrades: a serializer with a bumped revision is used after restore, and going back to the original serializer is rejected.
+               final String stateName = "test";
+               final CheckpointStreamFactory streamFactory = createStreamFactory();
+               final SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+               AbstractKeyedStateBackend<Integer> keyedBackend = createKeyedBackend(IntSerializer.INSTANCE);
+
+               try {
+                       TypeSerializer<InternalPriorityQueueTestBase.TestElement> serializer =
+                               InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE;
+
+                       KeyGroupedInternalPriorityQueue<InternalPriorityQueueTestBase.TestElement> priorityQueue =
+                               keyedBackend.create(stateName, serializer);
+
+                       priorityQueue.add(new InternalPriorityQueueTestBase.TestElement(42L, 0L));
+
+                       RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot =
+                               keyedBackend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
+
+                       KeyedStateHandle keyedStateHandle = runSnapshot(snapshot, sharedStateRegistry);
+
+                       keyedBackend.dispose();
+
+                       // test restore with a modified but compatible serializer ---------------------------
+
+                       keyedBackend = restoreKeyedBackend(IntSerializer.INSTANCE, keyedStateHandle);
+
+                       serializer = new ModifiedTestElementSerializer();
+
+                       priorityQueue = keyedBackend.create(stateName, serializer);
+
+                       final InternalPriorityQueueTestBase.TestElement checkElement =
+                               new InternalPriorityQueueTestBase.TestElement(4711L, 1L);
+                       priorityQueue.add(checkElement);
+
+                       snapshot = keyedBackend.snapshot(1L, 1L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
+
+                       keyedStateHandle = runSnapshot(snapshot, sharedStateRegistry);
+
+                       keyedBackend.dispose();
+
+                       // test that the modified serializer was actually used ---------------------------
+
+                       keyedBackend = restoreKeyedBackend(IntSerializer.INSTANCE, keyedStateHandle);
+                       priorityQueue = keyedBackend.create(stateName, serializer);
+                       // discard the element added before the first snapshot; afterwards only the (re-serialized) checkElement must remain
+                       priorityQueue.poll();
+                       // round-trip checkElement through the modified serializer to compute the value the restored queue must contain
+                       ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos();
+                       DataOutputViewStreamWrapper outWrapper = new DataOutputViewStreamWrapper(out);
+                       serializer.serialize(checkElement, outWrapper);
+                       InternalPriorityQueueTestBase.TestElement expected =
+                               serializer.deserialize(new DataInputViewStreamWrapper(new ByteArrayInputStreamWithPos(out.toByteArray())));
+
+                       Assert.assertEquals(
+                               expected,
+                               priorityQueue.poll());
+                       Assert.assertTrue(priorityQueue.isEmpty());
+
+                       keyedBackend.dispose();
+
+                       // test that incompatible serializer is rejected ---------------------------
+
+                       serializer = InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE;
+                       keyedBackend = restoreKeyedBackend(IntSerializer.INSTANCE, keyedStateHandle);
+
+                       try {
+                               // this is expected to fail, because the old and new serializers should be incompatible due to
+                               // their different revision numbers.
+                               keyedBackend.create(stateName, serializer);
+                               Assert.fail("Expected exception from incompatible serializer.");
+                       } catch (Exception e) {
+                               Assert.assertTrue("Exception was not caused by state migration: " + e,
+                                       ExceptionUtils.findThrowable(e, StateMigrationException.class).isPresent());
+                       }
+               } finally {
+                       keyedBackend.dispose();
+               }
+       }
+
+       public static class ModifiedTestElementSerializer extends InternalPriorityQueueTestBase.TestElementSerializer {
+               // Test-only variant of TestElementSerializer: writes every element with key and priority shifted by one and reports a revision one above its parent's.
+               @Override
+               public void serialize(InternalPriorityQueueTestBase.TestElement element, DataOutputView out) throws IOException {
+                       super.serialize(new InternalPriorityQueueTestBase.TestElement(1 + element.getKey(), 1 + element.getPriority()), out);
+               }
+               // bumped revision, presumably consulted by the parent's compatibility logic so restore can tell the two serializers apart
+               @Override
+               protected int getRevision() {
+                       return 1 + super.getRevision();
+               }
+       }
+
+       @Test
        @SuppressWarnings("unchecked")
        public void testValueState() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ValueStateDescriptor<String> kvId = new 
ValueStateDescriptor<>("id", String.class);
@@ -1138,7 +1274,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals("1", getSerializedValue(kvState, 1, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
                // draw a snapshot
-               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot1 = runSnapshot(
+                       backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -1149,7 +1287,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.update("u3");
 
                // draw another snapshot
-               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot2 = runSnapshot(
+                       backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -1320,7 +1460,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked")
        public void testMultipleValueStates() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
-
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
                                IntSerializer.INSTANCE,
                                1,
@@ -1350,7 +1490,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                // draw a snapshot
                KeyedStateHandle snapshot1 =
-                       runSnapshot(backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
                backend.dispose();
                backend = restoreKeyedBackend(
@@ -1394,6 +1536,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                }
 
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ValueStateDescriptor<Long> kvId = new 
ValueStateDescriptor<>("id", LongSerializer.INSTANCE, 42L);
@@ -1422,7 +1565,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals(42L, (long) state.value());
 
                // draw a snapshot
-               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot1 = runSnapshot(
+                       backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                backend.dispose();
                backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
@@ -1438,6 +1583,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked,rawtypes")
        public void testListState() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ListStateDescriptor<String> kvId = new 
ListStateDescriptor<>("id", String.class);
@@ -1470,7 +1616,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals("1", joiner.join(getSerializedList(kvState, 1, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
                // draw a snapshot
-               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot1 = runSnapshot(
+                       backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -1483,7 +1631,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.add("u3");
 
                // draw another snapshot
-               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot2 = runSnapshot(
+                       backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -1756,7 +1906,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.clear();
 
                        // make sure all lists / maps are cleared
-                       assertThat("State backend is not empty.", 
keyedBackend.numStateEntries(), is(0));
+                       assertThat("State backend is not empty.", 
keyedBackend.numKeyValueStateEntries(), is(0));
                } finally {
                        keyedBackend.close();
                        keyedBackend.dispose();
@@ -1870,7 +2020,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.setCurrentNamespace(namespace1);
                        state.clear();
 
-                       assertThat("State backend is not empty.", 
keyedBackend.numStateEntries(), is(0));
+                       assertThat("State backend is not empty.", 
keyedBackend.numKeyValueStateEntries(), is(0));
                }
                finally {
                        keyedBackend.close();
@@ -1882,6 +2032,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked")
        public void testReducingState() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ReducingStateDescriptor<String> kvId = new 
ReducingStateDescriptor<>("id", new AppendingReduce(), String.class);
@@ -1910,7 +2061,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals("1", getSerializedValue(kvState, 1, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
                // draw a snapshot
-               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot1 = runSnapshot(
+                       backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -1921,7 +2074,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.add("u3");
 
                // draw another snapshot
-               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot2 = runSnapshot(
+                       backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -2019,7 +2174,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.clear();
 
                        // make sure all lists / maps are cleared
-                       assertThat("State backend is not empty.", 
keyedBackend.numStateEntries(), is(0));
+                       assertThat("State backend is not empty.", 
keyedBackend.numKeyValueStateEntries(), is(0));
                }
                finally {
                        keyedBackend.close();
@@ -2137,7 +2292,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.setCurrentNamespace(namespace1);
                        state.clear();
 
-                       assertThat("State backend is not empty.", 
keyedBackend.numStateEntries(), is(0));
+                       assertThat("State backend is not empty.", 
keyedBackend.numKeyValueStateEntries(), is(0));
                }
                finally {
                        keyedBackend.close();
@@ -2192,7 +2347,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.clear();
 
                        // make sure all lists / maps are cleared
-                       assertThat("State backend is not empty.", 
keyedBackend.numStateEntries(), is(0));
+                       assertThat("State backend is not empty.", 
keyedBackend.numKeyValueStateEntries(), is(0));
                }
                finally {
                        keyedBackend.close();
@@ -2310,7 +2465,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.setCurrentNamespace(namespace1);
                        state.clear();
 
-                       assertThat("State backend is not empty.", 
keyedBackend.numStateEntries(), is(0));
+                       assertThat("State backend is not empty.", 
keyedBackend.numKeyValueStateEntries(), is(0));
                }
                finally {
                        keyedBackend.close();
@@ -2365,7 +2520,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.clear();
 
                        // make sure all lists / maps are cleared
-                       assertThat("State backend is not empty.", 
keyedBackend.numStateEntries(), is(0));
+                       assertThat("State backend is not empty.", 
keyedBackend.numKeyValueStateEntries(), is(0));
                }
                finally {
                        keyedBackend.close();
@@ -2483,7 +2638,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.setCurrentNamespace(namespace1);
                        state.clear();
 
-                       assertThat("State backend is not empty.", 
keyedBackend.numStateEntries(), is(0));
+                       assertThat("State backend is not empty.", 
keyedBackend.numKeyValueStateEntries(), is(0));
                }
                finally {
                        keyedBackend.close();
@@ -2495,6 +2650,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked,rawtypes")
        public void testFoldingState() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                FoldingStateDescriptor<Integer, String> kvId = new 
FoldingStateDescriptor<>("id",
@@ -2526,7 +2682,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                assertEquals("Fold-Initial:,1", getSerializedValue(kvState, 1, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
                // draw a snapshot
-               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot1 = runSnapshot(
+                       backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // make some more modifications
                backend.setCurrentKey(1);
@@ -2538,7 +2696,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.add(103);
 
                // draw another snapshot
-               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot2 = runSnapshot(
+                       backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // validate the original state
                backend.setCurrentKey(1);
@@ -2594,6 +2754,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked,rawtypes")
        public void testMapState() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<String> backend = 
createKeyedBackend(StringSerializer.INSTANCE);
 
                MapStateDescriptor<Integer, String> kvId = new 
MapStateDescriptor<>("id", Integer.class, String.class);
@@ -2633,7 +2794,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        getSerializedMap(kvState, "11", keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
 
                // draw a snapshot
-               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot1 = runSnapshot(
+                       backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // make some more modifications
                backend.setCurrentKey("1");
@@ -2645,7 +2808,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.putAll(new HashMap<Integer, String>() {{ put(1031, 
"1031"); put(1032, "1032"); }});
 
                // draw another snapshot
-               KeyedStateHandle snapshot2 = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot2 = runSnapshot(
+                       backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                // validate the original state
                backend.setCurrentKey("1");
@@ -2920,6 +3085,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                final int MAX_PARALLELISM = 10;
 
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                final AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(
                                IntSerializer.INSTANCE,
                                MAX_PARALLELISM,
@@ -2951,7 +3117,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.update("ShouldBeInSecondHalf");
 
 
-               KeyedStateHandle snapshot = runSnapshot(backend.snapshot(0, 0, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot = runSnapshot(
+                       backend.snapshot(0, 0, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                List<KeyedStateHandle> firstHalfKeyGroupStates = 
StateAssignmentOperation.getKeyedStateHandles(
                                Collections.singletonList(snapshot),
@@ -3004,6 +3172,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        public void testRestoreWithWrongKeySerializer() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
 
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+
                // use an IntSerializer at first
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
@@ -3018,7 +3188,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.update("2");
 
                // draw a snapshot
-               KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot1 = runSnapshot(
+                       backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                backend.dispose();
 
@@ -3036,6 +3208,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked")
        public void testValueStateRestoreWithWrongSerializers() throws 
Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                try {
@@ -3049,7 +3222,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.update("2");
 
                        // draw a snapshot
-                       KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot1 = runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
                        backend.dispose();
                        // restore the first snapshot and validate it
@@ -3080,6 +3255,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked")
        public void testListStateRestoreWithWrongSerializers() throws Exception 
{
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                try {
@@ -3092,7 +3268,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.add("2");
 
                        // draw a snapshot
-                       KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot1 = runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
                        backend.dispose();
                        // restore the first snapshot and validate it
@@ -3123,6 +3301,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked")
        public void testReducingStateRestoreWithWrongSerializers() throws 
Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                try {
@@ -3137,7 +3316,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.add("2");
 
                        // draw a snapshot
-                       KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot1 = runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
                        backend.dispose();
                        // restore the first snapshot and validate it
@@ -3168,6 +3349,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        @SuppressWarnings("unchecked")
        public void testMapStateRestoreWithWrongSerializers() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                try {
@@ -3180,7 +3362,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        state.put("2", "Second");
 
                        // draw a snapshot
-                       KeyedStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+                       KeyedStateHandle snapshot1 = runSnapshot(
+                               backend.snapshot(682375462378L, 2, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry);
 
                        backend.dispose();
                        // restore the first snapshot and validate it
@@ -3421,6 +3605,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                KvStateRegistry registry = env.getKvStateRegistry();
 
                CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
                KeyGroupRange expectedKeyGroupRange = 
backend.getKeyGroupRange();
 
@@ -3439,7 +3624,9 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                                eq(env.getJobID()), eq(env.getJobVertexId()), 
eq(expectedKeyGroupRange), eq("banana"), any(KvStateID.class));
 
 
-               KeyedStateHandle snapshot = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()));
+               KeyedStateHandle snapshot = runSnapshot(
+                       backend.snapshot(682375462379L, 4, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
+                       sharedStateRegistry);
 
                backend.dispose();
 
@@ -3465,13 +3652,16 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                try {
                        CheckpointStreamFactory streamFactory = 
createStreamFactory();
+                       SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                        AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                        ListStateDescriptor<String> kvId = new 
ListStateDescriptor<>("id", String.class);
 
                        // draw a snapshot
                        KeyedStateHandle snapshot =
-                               runSnapshot(backend.snapshot(682375462379L, 1, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+                               runSnapshot(
+                                       backend.snapshot(682375462379L, 1, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                                       sharedStateRegistry);
                        assertNull(snapshot);
                        backend.dispose();
 
@@ -3491,7 +3681,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                ValueStateDescriptor<String> kvId = new 
ValueStateDescriptor<>("id", String.class);
 
-               assertEquals(0, backend.numStateEntries());
+               assertEquals(0, backend.numKeyValueStateEntries());
 
                ValueState<String> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
 
@@ -3499,22 +3689,22 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                state.update("hello");
                state.update("ciao");
 
-               assertEquals(1, backend.numStateEntries());
+               assertEquals(1, backend.numKeyValueStateEntries());
 
                backend.setCurrentKey(42);
                state.update("foo");
 
-               assertEquals(2, backend.numStateEntries());
+               assertEquals(2, backend.numKeyValueStateEntries());
 
                backend.setCurrentKey(0);
                state.clear();
 
-               assertEquals(1, backend.numStateEntries());
+               assertEquals(1, backend.numKeyValueStateEntries());
 
                backend.setCurrentKey(42);
                state.clear();
 
-               assertEquals(0, backend.numStateEntries());
+               assertEquals(0, backend.numKeyValueStateEntries());
 
                backend.dispose();
        }
@@ -4048,14 +4238,19 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        }
 
        protected KeyedStateHandle runSnapshot(
-               RunnableFuture<SnapshotResult<KeyedStateHandle>> 
snapshotRunnableFuture) throws Exception {
+               RunnableFuture<SnapshotResult<KeyedStateHandle>> 
snapshotRunnableFuture,
+               SharedStateRegistry sharedStateRegistry) throws Exception {
 
                if (!snapshotRunnableFuture.isDone()) {
                        snapshotRunnableFuture.run();
                }
 
                SnapshotResult<KeyedStateHandle> snapshotResult = 
snapshotRunnableFuture.get();
-               return snapshotResult.getJobManagerOwnedSnapshot();
+               KeyedStateHandle jobManagerOwnedSnapshot = 
snapshotResult.getJobManagerOwnedSnapshot();
+               if (jobManagerOwnedSnapshot != null) {
+                       
jobManagerOwnedSnapshot.registerSharedStates(sharedStateRegistry);
+               }
+               return jobManagerOwnedSnapshot;
        }
 
        public static class TestPojo implements Serializable {
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
index 7b8d69f..1ca3e80 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
@@ -235,7 +235,7 @@ public class HeapKeyedStateBackendSnapshotMigrationTest 
extends HeapStateBackend
 
                        InternalListState<String, Integer, Long> state = 
keyedBackend.createInternalState(IntSerializer.INSTANCE, stateDescr);
 
-                       assertEquals(7, keyedBackend.numStateEntries());
+                       assertEquals(7, keyedBackend.numKeyValueStateEntries());
 
                        keyedBackend.setCurrentKey("abc");
                        state.setCurrentNamespace(namespace1);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 9e9328b..805ae1c 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -109,7 +109,7 @@ public class MockKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        }
 
        @Override
-       public int numStateEntries() {
+       public int numKeyValueStateEntries() {
                int count = 0;
                for (String state : stateValues.keySet()) {
                        for (K key : stateValues.get(state).keySet()) {
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 4af5a27..7ead620 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -27,6 +27,7 @@ import 
org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.CompatibilityResult;
 import org.apache.flink.api.common.typeutils.CompatibilityUtil;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer;
@@ -1319,7 +1320,6 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                RegisteredKeyValueStateBackendMetaInfo<N, S> newMetaInfo;
                if (stateInfo != null) {
 
-                       @SuppressWarnings("unchecked")
                        StateMetaInfoSnapshot restoredMetaInfoSnapshot = 
restoredKvStateMetaInfos.get(stateDesc.getName());
 
                        Preconditions.checkState(
@@ -1398,7 +1398,7 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        @VisibleForTesting
        @SuppressWarnings("unchecked")
        @Override
-       public int numStateEntries() {
+       public int numKeyValueStateEntries() {
                int count = 0;
 
                for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> 
column : kvStateInformation.values()) {
@@ -2668,10 +2668,10 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                public <T extends HeapPriorityQueueElement & PriorityComparable 
& Keyed> KeyGroupedInternalPriorityQueue<T>
                create(@Nonnull String stateName, @Nonnull TypeSerializer<T> 
byteOrderedElementSerializer) {
 
-                       final Tuple2<ColumnFamilyHandle, 
RegisteredStateMetaInfoBase> entry =
+                       final Tuple2<ColumnFamilyHandle, 
RegisteredStateMetaInfoBase> metaInfoTuple =
                                tryRegisterPriorityQueueMetaInfo(stateName, 
byteOrderedElementSerializer);
 
-                       final ColumnFamilyHandle columnFamilyHandle = entry.f0;
+                       final ColumnFamilyHandle columnFamilyHandle = 
metaInfoTuple.f0;
 
                        return new KeyGroupPartitionedPriorityQueue<>(
                                KeyExtractorFunction.forKeyedObjects(),
@@ -2708,20 +2708,51 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                @Nonnull String stateName,
                @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
 
-               Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> entry =
+               Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> 
metaInfoTuple =
                        kvStateInformation.get(stateName);
 
-               if (entry == null) {
+               if (metaInfoTuple == null) {
+                       final ColumnFamilyHandle columnFamilyHandle = 
createColumnFamily(stateName);
+
                        RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo 
=
                                new 
RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, 
byteOrderedElementSerializer);
 
-                       final ColumnFamilyHandle columnFamilyHandle = 
createColumnFamily(stateName);
+                       metaInfoTuple = new Tuple2<>(columnFamilyHandle, 
metaInfo);
+                       kvStateInformation.put(stateName, metaInfoTuple);
+               } else {
+                       // TODO we implement the simple way of supporting the 
current functionality, mimicking keyed state
+                       // because this should be reworked in FLINK-9376 and 
then we should have a common algorithm over
+                       // StateMetaInfoSnapshot that avoids this code 
duplication.
+                       StateMetaInfoSnapshot restoredMetaInfoSnapshot = 
restoredKvStateMetaInfos.get(stateName);
 
-                       entry = new Tuple2<>(columnFamilyHandle, metaInfo);
-                       kvStateInformation.put(stateName, entry);
+                       Preconditions.checkState(
+                               restoredMetaInfoSnapshot != null,
+                               "Requested to check compatibility of a restored 
RegisteredKeyedBackendStateMetaInfo," +
+                                       " but its corresponding restored 
snapshot cannot be found.");
+
+                       StateMetaInfoSnapshot.CommonSerializerKeys 
serializerKey =
+                               
StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER;
+
+                       TypeSerializer<?> metaInfoTypeSerializer = 
restoredMetaInfoSnapshot.getTypeSerializer(serializerKey);
+
+                       if (metaInfoTypeSerializer != 
byteOrderedElementSerializer) {
+                               CompatibilityResult<T> compatibilityResult = 
CompatibilityUtil.resolveCompatibilityResult(
+                                       metaInfoTypeSerializer,
+                                       null,
+                                       
restoredMetaInfoSnapshot.getTypeSerializerConfigSnapshot(serializerKey),
+                                       byteOrderedElementSerializer);
+
+                               if (compatibilityResult.isRequiresMigration()) {
+                                       throw new 
FlinkRuntimeException(StateMigrationException.notSupported());
+                               }
+
+                               // update meta info with new serializer
+                               metaInfoTuple.f1 =
+                                       new 
RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, 
byteOrderedElementSerializer);
+                       }
                }
 
-               return entry;
+               return metaInfoTuple;
        }
 
        @Override
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index 6b254ce..0ea0d3f 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.execution.Environment;
@@ -124,6 +125,11 @@ public class RocksDBStateBackendTest extends 
StateBackendTestBase<RocksDBStateBa
                dbPath = tempFolder.newFolder().getAbsolutePath();
                String checkpointPath = 
tempFolder.newFolder().toURI().toString();
                RocksDBStateBackend backend = new RocksDBStateBackend(new 
FsStateBackend(checkpointPath), enableIncrementalCheckpointing);
+               Configuration configuration = new Configuration();
+               configuration.setString(
+                       RocksDBOptions.TIMER_SERVICE_FACTORY,
+                       
RocksDBStateBackend.PriorityQueueStateType.ROCKSDB.toString());
+               backend = backend.configure(configuration);
                backend.setDbStoragePath(dbPath);
                return backend;
        }
diff --git a/flink-streaming-java/pom.xml b/flink-streaming-java/pom.xml
index 02da827..e64ed48 100644
--- a/flink-streaming-java/pom.xml
+++ b/flink-streaming-java/pom.xml
@@ -42,6 +42,8 @@ under the License.
                        <groupId>org.apache.flink</groupId>
                        <artifactId>flink-core</artifactId>
                        <version>${project.version}</version>
+                       <scope>test</scope>
+                       <type>test-jar</type>
                </dependency>
 
                <dependency>
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
index b54a1a9..ff48c3f 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
@@ -46,12 +46,13 @@ import java.util.Map;
 @Internal
 public class InternalTimeServiceManager<K> {
 
-       //TODO guard these constants with a test
-       private static final String TIMER_STATE_PREFIX = "_timer_state";
-       private static final String PROCESSING_TIMER_PREFIX = 
TIMER_STATE_PREFIX + "/processing_";
-       private static final String EVENT_TIMER_PREFIX = TIMER_STATE_PREFIX + 
"/event_";
+       @VisibleForTesting
+       static final String TIMER_STATE_PREFIX = "_timer_state";
+       @VisibleForTesting
+       static final String PROCESSING_TIMER_PREFIX = TIMER_STATE_PREFIX + 
"/processing_";
+       @VisibleForTesting
+       static final String EVENT_TIMER_PREFIX = TIMER_STATE_PREFIX + "/event_";
 
-       private final int totalKeyGroups;
        private final KeyGroupRange localKeyGroupRange;
        private final KeyContext keyContext;
 
@@ -63,14 +64,11 @@ public class InternalTimeServiceManager<K> {
        private final boolean useLegacySynchronousSnapshots;
 
        InternalTimeServiceManager(
-               int totalKeyGroups,
                KeyGroupRange localKeyGroupRange,
                KeyContext keyContext,
                PriorityQueueSetFactory priorityQueueSetFactory,
                ProcessingTimeService processingTimeService, boolean 
useLegacySynchronousSnapshots) {
 
-               Preconditions.checkArgument(totalKeyGroups > 0);
-               this.totalKeyGroups = totalKeyGroups;
                this.localKeyGroupRange = 
Preconditions.checkNotNull(localKeyGroupRange);
                this.priorityQueueSetFactory = 
Preconditions.checkNotNull(priorityQueueSetFactory);
                this.keyContext = Preconditions.checkNotNull(keyContext);
@@ -155,10 +153,6 @@ public class InternalTimeServiceManager<K> {
                serializationProxy.read(stream);
        }
 
-       public boolean isUseLegacySynchronousSnapshots() {
-               return useLegacySynchronousSnapshots;
-       }
-
        ////////////////////                    Methods used ONLY IN TESTS      
                        ////////////////////
 
        @VisibleForTesting
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
index a6bee4c..64af993 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
@@ -205,7 +205,6 @@ public class StreamTaskStateInitializerImpl implements 
StreamTaskStateInitialize
                final KeyGroupRange keyGroupRange = 
keyedStatedBackend.getKeyGroupRange();
 
                final InternalTimeServiceManager<K> timeServiceManager = new 
InternalTimeServiceManager<>(
-                       keyedStatedBackend.getNumberOfKeyGroups(),
                        keyGroupRange,
                        keyContext,
                        keyedStatedBackend,
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
index 73f42ef..a83cc3a 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
@@ -19,9 +19,12 @@
 package org.apache.flink.streaming.api.operators;
 
 import org.apache.flink.api.common.typeutils.CompatibilityResult;
+import org.apache.flink.api.common.typeutils.CompatibilityUtil;
 import 
org.apache.flink.api.common.typeutils.CompositeTypeSerializerConfigSnapshot;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot;
+import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.util.MathUtils;
@@ -29,6 +32,7 @@ import org.apache.flink.util.MathUtils;
 import javax.annotation.Nonnull;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Objects;
 
 /**
@@ -42,6 +46,9 @@ public class TimerSerializer<K, N> extends 
TypeSerializer<TimerHeapInternalTimer
 
        private static final long serialVersionUID = 1L;
 
+       private static final int KEY_SERIALIZER_SNAPSHOT_INDEX = 0;
+       private static final int NAMESPACE_SERIALIZER_SNAPSHOT_INDEX = 1;
+
        /** Serializer for the key. */
        @Nonnull
        private final TypeSerializer<K> keySerializer;
@@ -208,8 +215,35 @@ public class TimerSerializer<K, N> extends 
TypeSerializer<TimerHeapInternalTimer
        @Override
        public CompatibilityResult<TimerHeapInternalTimer<K, N>> 
ensureCompatibility(
                TypeSerializerConfigSnapshot configSnapshot) {
-               //TODO this is just a mock (assuming no serializer updates) for 
now and needs a proper implementation! change this before release.
-               return CompatibilityResult.compatible();
+
+               if (configSnapshot instanceof TimerSerializerConfigSnapshot) {
+                       List<Tuple2<TypeSerializer<?>, 
TypeSerializerConfigSnapshot>> previousSerializersAndConfigs =
+                               ((TimerSerializerConfigSnapshot) 
configSnapshot).getNestedSerializersAndConfigs();
+
+                       if (previousSerializersAndConfigs.size() == 2) {
+                               Tuple2<TypeSerializer<?>, 
TypeSerializerConfigSnapshot> keySerializerAndSnapshot =
+                                       
previousSerializersAndConfigs.get(KEY_SERIALIZER_SNAPSHOT_INDEX);
+                               Tuple2<TypeSerializer<?>, 
TypeSerializerConfigSnapshot> namespaceSerializerAndSnapshot =
+                                       
previousSerializersAndConfigs.get(NAMESPACE_SERIALIZER_SNAPSHOT_INDEX);
+                               CompatibilityResult<K> keyCompatibilityResult = 
CompatibilityUtil.resolveCompatibilityResult(
+                                       keySerializerAndSnapshot.f0,
+                                       UnloadableDummyTypeSerializer.class,
+                                       keySerializerAndSnapshot.f1,
+                                       keySerializer);
+
+                               CompatibilityResult<N> 
namespaceCompatibilityResult = CompatibilityUtil.resolveCompatibilityResult(
+                                       namespaceSerializerAndSnapshot.f0,
+                                       UnloadableDummyTypeSerializer.class,
+                                       namespaceSerializerAndSnapshot.f1,
+                                       namespaceSerializer);
+
+                               if 
(!keyCompatibilityResult.isRequiresMigration()
+                                       && 
!namespaceCompatibilityResult.isRequiresMigration()) {
+                                       return CompatibilityResult.compatible();
+                               }
+                       }
+               }
+               return CompatibilityResult.requiresMigration();
        }
 
        @Nonnull
@@ -230,16 +264,29 @@ public class TimerSerializer<K, N> extends 
TypeSerializer<TimerHeapInternalTimer
         */
        public static class TimerSerializerConfigSnapshot<K, N> extends 
CompositeTypeSerializerConfigSnapshot {
 
+               private static final int VERSION = 1;
+
                public TimerSerializerConfigSnapshot() {
                }
 
-               public TimerSerializerConfigSnapshot(TypeSerializer<K> 
keySerializer, TypeSerializer<N> namespaceSerializer) {
-                       super(keySerializer, namespaceSerializer);
+               public TimerSerializerConfigSnapshot(
+                       @Nonnull TypeSerializer<K> keySerializer,
+                       @Nonnull TypeSerializer<N> namespaceSerializer) {
+                       super(init(keySerializer, namespaceSerializer));
+               }
+
+               private static TypeSerializer<?>[] init(
+                       @Nonnull TypeSerializer<?> keySerializer,
+                       @Nonnull TypeSerializer<?> namespaceSerializer) {
+                       TypeSerializer<?>[] timerSerializers = new 
TypeSerializer[2];
+                       timerSerializers[KEY_SERIALIZER_SNAPSHOT_INDEX] = 
keySerializer;
+                       timerSerializers[NAMESPACE_SERIALIZER_SNAPSHOT_INDEX] = 
namespaceSerializer;
+                       return timerSerializers;
                }
 
                @Override
                public int getVersion() {
-                       return 0;
+                       return VERSION;
                }
        }
 }
diff --git 
a/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerTest.java
similarity index 51%
copy from 
flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
copy to 
flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerTest.java
index 00e0e73..905e8d7 100644
--- 
a/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerTest.java
@@ -16,23 +16,26 @@
  * limitations under the License.
  */
 
-package org.apache.flink.util;
+package org.apache.flink.streaming.api.operators;
 
-/**
- * Base class for state migration related exceptions.
- */
-public class StateMigrationException extends FlinkException {
-       private static final long serialVersionUID = 8268516412747670839L;
+import org.apache.flink.util.TestLogger;
 
-       public StateMigrationException(String message) {
-               super(message);
-       }
+import org.junit.Assert;
+import org.junit.Test;
 
-       public StateMigrationException(Throwable cause) {
-               super(cause);
-       }
+/**
+ * Tests for {@link InternalTimeServiceManager}.
+ */
+public class InternalTimeServiceManagerTest extends TestLogger {
 
-       public StateMigrationException(String message, Throwable cause) {
-               super(message, cause);
+       /**
+        * This test fixes some constants, because changing them can harm 
backwards compatibility.
+        */
+       @Test
+       public void fixConstants() {
+               String expectedTimerStatePrefix = "_timer_state";
+               Assert.assertEquals(expectedTimerStatePrefix, 
InternalTimeServiceManager.TIMER_STATE_PREFIX);
+               Assert.assertEquals(expectedTimerStatePrefix + "/processing_", 
InternalTimeServiceManager.PROCESSING_TIMER_PREFIX);
+               Assert.assertEquals(expectedTimerStatePrefix + "/event_", 
InternalTimeServiceManager.EVENT_TIMER_PREFIX);
        }
 }
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/TimerSerializerTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/TimerSerializerTest.java
new file mode 100644
index 0000000..9fe4ffc
--- /dev/null
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/TimerSerializerTest.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.typeutils.SerializerTestBase;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+
+/**
+ * Test for {@link TimerSerializer}.
+ */
+public class TimerSerializerTest extends 
SerializerTestBase<TimerHeapInternalTimer<Long, TimeWindow>> {
+
+       private static final TypeSerializer<Long> KEY_SERIALIZER = 
LongSerializer.INSTANCE;
+       private static final TypeSerializer<TimeWindow> NAMESPACE_SERIALIZER = 
new TimeWindow.Serializer();
+
+       @Override
+       protected TypeSerializer<TimerHeapInternalTimer<Long, TimeWindow>> 
createSerializer() {
+               return new TimerSerializer<>(KEY_SERIALIZER, 
NAMESPACE_SERIALIZER);
+       }
+
+       @Override
+       protected int getLength() {
+               return Long.BYTES + KEY_SERIALIZER.getLength() + 
NAMESPACE_SERIALIZER.getLength();
+       }
+
+       @SuppressWarnings("unchecked")
+       @Override
+       protected Class<TimerHeapInternalTimer<Long, TimeWindow>> 
getTypeClass() {
+               return (Class<TimerHeapInternalTimer<Long, TimeWindow>>) 
(Class<?>) TimerHeapInternalTimer.class;
+       }
+
+       @SuppressWarnings("unchecked")
+       @Override
+       protected TimerHeapInternalTimer<Long, TimeWindow>[] getTestData() {
+               return (TimerHeapInternalTimer<Long, TimeWindow>[]) new 
TimerHeapInternalTimer[]{
+                       new TimerHeapInternalTimer<>(42L, 4711L, new 
TimeWindow(1000L, 2000L)),
+                       new TimerHeapInternalTimer<>(0L, 0L, new TimeWindow(0L, 
0L)),
+                       new TimerHeapInternalTimer<>(1L, -1L, new 
TimeWindow(1L, -1L)),
+                       new TimerHeapInternalTimer<>(-1L, 1L, new 
TimeWindow(-1L, 1L)),
+                       new TimerHeapInternalTimer<>(Long.MAX_VALUE, 
Long.MIN_VALUE, new TimeWindow(Long.MAX_VALUE, Long.MIN_VALUE)),
+                       new TimerHeapInternalTimer<>(Long.MIN_VALUE, 
Long.MAX_VALUE, new TimeWindow(Long.MIN_VALUE, Long.MAX_VALUE))
+               };
+       }
+}
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TriggerTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TriggerTestHarness.java
index 1536956..bc5bb1b 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TriggerTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TriggerTestHarness.java
@@ -115,11 +115,11 @@ public class TriggerTestHarness<T, W extends Window> {
        }
 
        public int numStateEntries() {
-               return stateBackend.numStateEntries();
+               return stateBackend.numKeyValueStateEntries();
        }
 
        public int numStateEntries(W window) {
-               return stateBackend.numStateEntries(window);
+               return stateBackend.numKeyValueStateEntries(window);
        }
 
        /**
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 2035c46..caf846f 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -72,7 +72,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, 
OUT>
                AbstractStreamOperator<?> abstractStreamOperator = 
(AbstractStreamOperator<?>) operator;
                KeyedStateBackend<Object> keyedStateBackend = 
abstractStreamOperator.getKeyedStateBackend();
                if (keyedStateBackend instanceof HeapKeyedStateBackend) {
-                       return ((HeapKeyedStateBackend) 
keyedStateBackend).numStateEntries();
+                       return ((HeapKeyedStateBackend) 
keyedStateBackend).numKeyValueStateEntries();
                } else {
                        throw new UnsupportedOperationException();
                }
@@ -82,7 +82,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, 
OUT>
                AbstractStreamOperator<?> abstractStreamOperator = 
(AbstractStreamOperator<?>) operator;
                KeyedStateBackend<Object> keyedStateBackend = 
abstractStreamOperator.getKeyedStateBackend();
                if (keyedStateBackend instanceof HeapKeyedStateBackend) {
-                       return ((HeapKeyedStateBackend) 
keyedStateBackend).numStateEntries(namespace);
+                       return ((HeapKeyedStateBackend) 
keyedStateBackend).numKeyValueStateEntries(namespace);
                } else {
                        throw new UnsupportedOperationException();
                }
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
index 607eee0..c00e59a 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
@@ -62,7 +62,7 @@ public class KeyedTwoInputStreamOperatorTestHarness<K, IN1, 
IN2, OUT>
                AbstractStreamOperator<?> abstractStreamOperator = 
(AbstractStreamOperator<?>) operator;
                KeyedStateBackend<Object> keyedStateBackend = 
abstractStreamOperator.getKeyedStateBackend();
                if (keyedStateBackend instanceof HeapKeyedStateBackend) {
-                       return ((HeapKeyedStateBackend) 
keyedStateBackend).numStateEntries();
+                       return ((HeapKeyedStateBackend) 
keyedStateBackend).numKeyValueStateEntries();
                } else {
                        throw new UnsupportedOperationException();
                }

Reply via email to