[FLINK-5421] Add explicit restore() method in Snapshotable

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/39fc07f8
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/39fc07f8
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/39fc07f8

Branch: refs/heads/release-1.2
Commit: 39fc07f87edd33ee78459b2d08b9d767efb100cc
Parents: 80f1517
Author: Stefan Richter <[email protected]>
Authored: Thu Jan 5 23:45:13 2017 +0100
Committer: Aljoscha Krettek <[email protected]>
Committed: Thu Jan 12 16:41:33 2017 +0100

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         |  67 +++-----
 .../streaming/state/RocksDBStateBackend.java    |  37 -----
 .../runtime/state/AbstractStateBackend.java     |  32 +---
 .../state/DefaultOperatorStateBackend.java      | 151 ++++++++-----------
 .../flink/runtime/state/Snapshotable.java       |   9 ++
 .../state/StateInitializationContextImpl.java   |  66 +++++---
 .../state/filesystem/FsStateBackend.java        |  21 ---
 .../state/heap/HeapKeyedStateBackend.java       |  37 ++---
 .../state/memory/MemoryStateBackend.java        |  22 ---
 .../runtime/state/OperatorStateBackendTest.java |  11 +-
 .../runtime/state/StateBackendTestBase.java     |  21 ++-
 .../streaming/runtime/tasks/StreamTask.java     |  48 +++---
 .../runtime/tasks/BlockingCheckpointsTest.java  |  13 --
 .../tasks/InterruptSensitiveRestoreTest.java    | 122 ++++++++++++++-
 .../util/AbstractStreamOperatorTestHarness.java |  15 +-
 .../KeyedOneInputStreamOperatorTestHarness.java |  37 ++---
 .../KeyedTwoInputStreamOperatorTestHarness.java |  33 ++--
 .../streaming/runtime/StateBackendITCase.java   |  15 --
 18 files changed, 353 insertions(+), 404 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 1c0a4b7..71e2c79 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -180,51 +180,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                kvStateInformation = new HashMap<>();
        }
 
-       public RocksDBKeyedStateBackend(
-                       JobID jobId,
-                       String operatorIdentifier,
-                       ClassLoader userCodeClassLoader,
-                       File instanceBasePath,
-                       DBOptions dbOptions,
-                       ColumnFamilyOptions columnFamilyOptions,
-                       TaskKvStateRegistry kvStateRegistry,
-                       TypeSerializer<K> keySerializer,
-                       int numberOfKeyGroups,
-                       KeyGroupRange keyGroupRange,
-                       Collection<KeyGroupsStateHandle> restoreState
-       ) throws Exception {
-
-               this(jobId,
-                       operatorIdentifier,
-                       userCodeClassLoader,
-                       instanceBasePath,
-                       dbOptions,
-                       columnFamilyOptions,
-                       kvStateRegistry,
-                       keySerializer,
-                       numberOfKeyGroups,
-                       keyGroupRange);
-
-               LOG.info("Initializing RocksDB keyed state backend from 
snapshot.");
-
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug("Restoring snapshot from state handles: {}.", 
restoreState);
-               }
-
-               try {
-                       if 
(MigrationUtil.isOldSavepointKeyedState(restoreState)) {
-                               LOG.info("Converting RocksDB state from old 
savepoint.");
-                               restoreOldSavepointKeyedState(restoreState);
-                       } else {
-                               RocksDBRestoreOperation restoreOperation = new 
RocksDBRestoreOperation(this);
-                               restoreOperation.doRestore(restoreState);
-                       }
-               } catch (Exception ex) {
-                       dispose();
-                       throw ex;
-               }
-       }
-
        /**
         * Should only be called by one thread, and only after all accesses to 
the DB happened.
         */
@@ -631,6 +586,28 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                }
        }
 
+       @Override
+       public void restore(Collection<KeyGroupsStateHandle> restoreState) 
throws Exception {
+               LOG.info("Initializing RocksDB keyed state backend from 
snapshot.");
+
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug("Restoring snapshot from state handles: {}.", 
restoreState);
+               }
+
+               try {
+                       if 
(MigrationUtil.isOldSavepointKeyedState(restoreState)) {
+                               LOG.info("Converting RocksDB state from old 
savepoint.");
+                               restoreOldSavepointKeyedState(restoreState);
+                       } else {
+                               RocksDBRestoreOperation restoreOperation = new 
RocksDBRestoreOperation(this);
+                               restoreOperation.doRestore(restoreState);
+                       }
+               } catch (Exception ex) {
+                       dispose();
+                       throw ex;
+               }
+       }
+
        /**
         * Encapsulates the process of restoring a RocksDBKeyedStateBackend 
from a snapshot.
         */

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index c2e33d4..1e5620f 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -28,13 +28,10 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.util.AbstractID;
-
 import org.rocksdb.ColumnFamilyOptions;
 import org.rocksdb.DBOptions;
-
 import org.rocksdb.NativeLibraryLoader;
 import org.rocksdb.RocksDB;
 import org.slf4j.Logger;
@@ -46,7 +43,6 @@ import java.lang.reflect.Field;
 import java.net.URI;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.List;
 import java.util.Random;
 import java.util.UUID;
@@ -262,39 +258,6 @@ public class RocksDBStateBackend extends AbstractStateBackend {
                                keyGroupRange);
        }
 
-       @Override
-       public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-                       Environment env,
-                       JobID jobID,
-                       String operatorIdentifier,
-                       TypeSerializer<K> keySerializer,
-                       int numberOfKeyGroups,
-                       KeyGroupRange keyGroupRange,
-                       Collection<KeyGroupsStateHandle> restoredState,
-                       TaskKvStateRegistry kvStateRegistry) throws Exception {
-
-               // first, make sure that the RocksDB JNI library is loaded
-               // we do this explicitly here to have better error handling
-               String tempDir = 
env.getTaskManagerInfo().getTmpDirectories()[0];
-               ensureRocksDBIsLoaded(tempDir);
-
-               lazyInitializeForJob(env, operatorIdentifier);
-
-               File instanceBasePath = new File(getDbPath(), 
UUID.randomUUID().toString());
-               return new RocksDBKeyedStateBackend<>(
-                               jobID,
-                               operatorIdentifier,
-                               env.getUserClassLoader(),
-                               instanceBasePath,
-                               getDbOptions(),
-                               getColumnOptions(),
-                               kvStateRegistry,
-                               keySerializer,
-                               numberOfKeyGroups,
-                               keyGroupRange,
-                               restoredState);
-       }
-
        // 
------------------------------------------------------------------------
        //  Parameters
        // 
------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
index 1b53f1a..60d035a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
@@ -24,7 +24,6 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 
 import java.io.IOException;
-import java.util.Collection;
 
 /**
  * A state backend defines how state is stored and snapshotted during 
checkpoints.
@@ -59,41 +58,12 @@ public abstract class AbstractStateBackend implements java.io.Serializable {
        ) throws Exception;
 
        /**
-        * Creates a new {@link AbstractKeyedStateBackend} that restores its 
state from the given list
-        * {@link KeyGroupsStateHandle KeyGroupStateHandles}.
-        */
-       public abstract <K> AbstractKeyedStateBackend<K> 
restoreKeyedStateBackend(
-                       Environment env,
-                       JobID jobID,
-                       String operatorIdentifier,
-                       TypeSerializer<K> keySerializer,
-                       int numberOfKeyGroups,
-                       KeyGroupRange keyGroupRange,
-                       Collection<KeyGroupsStateHandle> restoredState,
-                       TaskKvStateRegistry kvStateRegistry
-       ) throws Exception;
-
-
-       /**
         * Creates a new {@link OperatorStateBackend} that can be used for 
storing partitionable operator
         * state in checkpoint streams.
         */
        public OperatorStateBackend createOperatorStateBackend(
                        Environment env,
-                       String operatorIdentifier
-       ) throws Exception {
+                       String operatorIdentifier) throws Exception {
                return new 
DefaultOperatorStateBackend(env.getUserClassLoader());
        }
-
-       /**
-        * Creates a new {@link OperatorStateBackend} that restores its state 
from the given collection of
-        * {@link OperatorStateHandle}.
-        */
-       public OperatorStateBackend restoreOperatorStateBackend(
-                       Environment env,
-                       String operatorIdentifier,
-                       Collection<OperatorStateHandle> restoreSnapshots
-       ) throws Exception {
-               return new 
DefaultOperatorStateBackend(env.getUserClassLoader(), restoreSnapshots);
-       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index d7a10d5..10bb409 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -50,33 +50,16 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
        public static final String DEFAULT_OPERATOR_STATE_NAME = "_default_";
        
        private final Map<String, PartitionableListState<?>> registeredStates;
-       private final Collection<OperatorStateHandle> restoreSnapshots;
        private final CloseableRegistry closeStreamOnCancelRegistry;
        private final JavaSerializer<Serializable> javaSerializer;
        private final ClassLoader userClassloader;
 
-       /**
-        * Restores a OperatorStateStore (lazily) using the provided snapshots.
-        *
-        * @param restoreSnapshots snapshots that are available to restore 
partitionable states on request.
-        */
-       public DefaultOperatorStateBackend(
-                       ClassLoader userClassLoader,
-                       Collection<OperatorStateHandle> restoreSnapshots) 
throws IOException {
+       public DefaultOperatorStateBackend(ClassLoader userClassLoader) throws 
IOException {
 
+               this.closeStreamOnCancelRegistry = new CloseableRegistry();
                this.userClassloader = 
Preconditions.checkNotNull(userClassLoader);
                this.javaSerializer = new JavaSerializer<>();
                this.registeredStates = new HashMap<>();
-               this.closeStreamOnCancelRegistry = new CloseableRegistry();
-               this.restoreSnapshots = restoreSnapshots;
-               restoreState();
-       }
-
-       /**
-        * Creates an empty OperatorStateStore.
-        */
-       public DefaultOperatorStateBackend(ClassLoader userClassLoader) throws 
IOException {
-               this(userClassLoader, null);
        }
 
        @SuppressWarnings("unchecked")
@@ -111,69 +94,6 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
                return partitionableListState;
        }
 
-       private void restoreState() throws IOException {
-
-               if (null == restoreSnapshots) {
-                       return;
-               }
-
-               for (OperatorStateHandle stateHandle : restoreSnapshots) {
-
-                       if (stateHandle == null) {
-                               continue;
-                       }
-
-                       FSDataInputStream in = stateHandle.openInputStream();
-                       closeStreamOnCancelRegistry.registerClosable(in);
-
-                       ClassLoader restoreClassLoader = 
Thread.currentThread().getContextClassLoader();
-
-                       try {
-                               
Thread.currentThread().setContextClassLoader(userClassloader);
-                               OperatorBackendSerializationProxy 
backendSerializationProxy =
-                                               new 
OperatorBackendSerializationProxy(userClassloader);
-
-                               backendSerializationProxy.read(new 
DataInputViewStreamWrapper(in));
-
-                               
List<OperatorBackendSerializationProxy.StateMetaInfo<?>> metaInfoList =
-                                               
backendSerializationProxy.getNamedStateSerializationProxies();
-
-                               // Recreate all PartitionableListStates from 
the meta info
-                               for 
(OperatorBackendSerializationProxy.StateMetaInfo<?> stateMetaInfo : 
metaInfoList) {
-                                       PartitionableListState<?> listState = 
registeredStates.get(stateMetaInfo.getName());
-
-                                       if (null == listState) {
-                                               listState = new 
PartitionableListState<>(
-                                                               
stateMetaInfo.getName(),
-                                                               
stateMetaInfo.getStateSerializer());
-
-                                               
registeredStates.put(listState.getName(), listState);
-                                       } else {
-                                               
Preconditions.checkState(listState.getPartitionStateSerializer().isCompatibleWith(
-                                                               
stateMetaInfo.getStateSerializer()), "Incompatible state serializers found: " +
-                                                               
listState.getPartitionStateSerializer() + " is not compatible with " +
-                                                               
stateMetaInfo.getStateSerializer());
-                                       }
-                               }
-
-                               // Restore all the state in 
PartitionableListStates
-                               for (Map.Entry<String, long[]> nameToOffsets : 
stateHandle.getStateNameToPartitionOffsets().entrySet()) {
-                                       PartitionableListState<?> 
stateListForName = registeredStates.get(nameToOffsets.getKey());
-
-                                       Preconditions.checkState(null != 
stateListForName, "Found state without " +
-                                                       "corresponding meta 
info: " + nameToOffsets.getKey());
-
-                                       
deserializeStateValues(stateListForName, in, nameToOffsets.getValue());
-                               }
-
-                       } finally {
-                               
Thread.currentThread().setContextClassLoader(restoreClassLoader);
-                               
closeStreamOnCancelRegistry.unregisterClosable(in);
-                               IOUtils.closeQuietly(in);
-                       }
-               }
-       }
-
        private static <S> void deserializeStateValues(
                        PartitionableListState<S> stateListForName,
                        FSDataInputStream in,
@@ -239,6 +159,70 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
        }
 
        @Override
+       public void restore(Collection<OperatorStateHandle> restoreSnapshots) 
throws Exception {
+
+               if (null == restoreSnapshots) {
+                       return;
+               }
+
+               for (OperatorStateHandle stateHandle : restoreSnapshots) {
+
+                       if (stateHandle == null) {
+                               continue;
+                       }
+
+                       FSDataInputStream in = stateHandle.openInputStream();
+                       closeStreamOnCancelRegistry.registerClosable(in);
+
+                       ClassLoader restoreClassLoader = 
Thread.currentThread().getContextClassLoader();
+
+                       try {
+                               
Thread.currentThread().setContextClassLoader(userClassloader);
+                               OperatorBackendSerializationProxy 
backendSerializationProxy =
+                                               new 
OperatorBackendSerializationProxy(userClassloader);
+
+                               backendSerializationProxy.read(new 
DataInputViewStreamWrapper(in));
+
+                               
List<OperatorBackendSerializationProxy.StateMetaInfo<?>> metaInfoList =
+                                               
backendSerializationProxy.getNamedStateSerializationProxies();
+
+                               // Recreate all PartitionableListStates from 
the meta info
+                               for 
(OperatorBackendSerializationProxy.StateMetaInfo<?> stateMetaInfo : 
metaInfoList) {
+                                       PartitionableListState<?> listState = 
registeredStates.get(stateMetaInfo.getName());
+
+                                       if (null == listState) {
+                                               listState = new 
PartitionableListState<>(
+                                                               
stateMetaInfo.getName(),
+                                                               
stateMetaInfo.getStateSerializer());
+
+                                               
registeredStates.put(listState.getName(), listState);
+                                       } else {
+                                               
Preconditions.checkState(listState.getPartitionStateSerializer().isCompatibleWith(
+                                                               
stateMetaInfo.getStateSerializer()), "Incompatible state serializers found: " +
+                                                               
listState.getPartitionStateSerializer() + " is not compatible with " +
+                                                               
stateMetaInfo.getStateSerializer());
+                                       }
+                               }
+
+                               // Restore all the state in 
PartitionableListStates
+                               for (Map.Entry<String, long[]> nameToOffsets : 
stateHandle.getStateNameToPartitionOffsets().entrySet()) {
+                                       PartitionableListState<?> 
stateListForName = registeredStates.get(nameToOffsets.getKey());
+
+                                       Preconditions.checkState(null != 
stateListForName, "Found state without " +
+                                                       "corresponding meta 
info: " + nameToOffsets.getKey());
+
+                                       
deserializeStateValues(stateListForName, in, nameToOffsets.getValue());
+                               }
+
+                       } finally {
+                               
Thread.currentThread().setContextClassLoader(restoreClassLoader);
+                               
closeStreamOnCancelRegistry.unregisterClosable(in);
+                               IOUtils.closeQuietly(in);
+                       }
+               }
+       }
+
+       @Override
        public void dispose() {
                registeredStates.clear();
        }
@@ -314,5 +298,4 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
                                        '}';
                }
        }
-}
-
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
index 2aa282d..a4a6bc4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
+import java.util.Collection;
 import java.util.concurrent.RunnableFuture;
 
 /**
@@ -42,4 +43,12 @@ public interface Snapshotable<S extends StateObject> {
                        long checkpointId,
                        long timestamp,
                        CheckpointStreamFactory streamFactory) throws Exception;
+
+       /**
+        * Restores state that was previously snapshotted from the provided 
parameters. Typically the parameters are state
+        * handles from which the old state is read.
+        *
+        * @param state the old state to restore.
+        */
+       void restore(Collection<S> state) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
index c86ff6c..be59a2a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -30,6 +30,7 @@ import java.io.IOException;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
+import java.util.NoSuchElementException;
 
 /**
  * Default implementation of {@link StateInitializationContext}.
@@ -155,19 +156,21 @@ public class StateInitializationContextImpl implements StateInitializationContex
                public boolean hasNext() {
                        if (null != currentStateHandle && 
currentOffsetsIterator.hasNext()) {
                                return true;
-                       } else {
-                               while (stateHandleIterator.hasNext()) {
-                                       currentStateHandle = 
stateHandleIterator.next();
-                                       if 
(currentStateHandle.getNumberOfKeyGroups() > 0) {
-                                               currentOffsetsIterator = 
currentStateHandle.getGroupRangeOffsets().iterator();
-                                               
closableRegistry.unregisterClosable(currentStream);
-                                               
IOUtils.closeQuietly(currentStream);
-                                               currentStream = null;
-                                               return true;
-                                       }
+                       }
+
+                       while (stateHandleIterator.hasNext()) {
+                               currentStateHandle = stateHandleIterator.next();
+                               if (currentStateHandle.getNumberOfKeyGroups() > 
0) {
+                                       currentOffsetsIterator = 
currentStateHandle.getGroupRangeOffsets().iterator();
+                                       
closableRegistry.unregisterClosable(currentStream);
+                                       IOUtils.closeQuietly(currentStream);
+                                       currentStream = null;
+
+                                       return true;
                                }
-                               return false;
                        }
+
+                       return false;
                }
 
                private void openStream() throws IOException {
@@ -178,6 +181,11 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
                @Override
                public KeyGroupStatePartitionStreamProvider next() {
+
+                       if (!hasNext()) {
+                               throw new NoSuchElementException("Iterator 
exhausted");
+                       }
+
                        Tuple2<Integer, Long> keyGroupOffset = 
currentOffsetsIterator.next();
                        try {
                                if (null == currentStream) {
@@ -220,26 +228,28 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
                @Override
                public boolean hasNext() {
-                       if (null != currentStateHandle && offPos < 
offsets.length) {
+
+                       if (null != offsets && offPos < offsets.length) {
                                return true;
-                       } else {
-                               while (stateHandleIterator.hasNext()) {
-                                       currentStateHandle = 
stateHandleIterator.next();
-                                       long[] offsets = 
currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
-                                       if (null != offsets && offsets.length > 
0) {
+                       }
+
+                       while (stateHandleIterator.hasNext()) {
+                               currentStateHandle = stateHandleIterator.next();
+                               long[] offsets = 
currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
+                               if (null != offsets && offsets.length > 0) {
 
-                                               this.offsets = offsets;
-                                               this.offPos = 0;
+                                       this.offsets = offsets;
+                                       this.offPos = 0;
 
-                                               
closableRegistry.unregisterClosable(currentStream);
-                                               
IOUtils.closeQuietly(currentStream);
-                                               currentStream = null;
+                                       
closableRegistry.unregisterClosable(currentStream);
+                                       IOUtils.closeQuietly(currentStream);
+                                       currentStream = null;
 
-                                               return true;
-                                       }
+                                       return true;
                                }
-                               return false;
                        }
+
+                       return false;
                }
 
                private void openStream() throws IOException {
@@ -250,7 +260,13 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
                @Override
                public StatePartitionStreamProvider next() {
+
+                       if (!hasNext()) {
+                               throw new NoSuchElementException("Iterator 
exhausted");
+                       }
+
                        long offset = offsets[offPos++];
+
                        try {
                                if (null == currentStream) {
                                        openStream();

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index 4e15cd5..281dbb0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -28,7 +28,6 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -36,7 +35,6 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.net.URI;
 import java.net.URISyntaxException;
-import java.util.Collection;
 
 /**
  * The file state backend is a state backend that stores the state of 
streaming jobs in a file system.
@@ -192,25 +190,6 @@ public class FsStateBackend extends AbstractStateBackend {
        }
 
        @Override
-       public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-                       Environment env,
-                       JobID jobID,
-                       String operatorIdentifier,
-                       TypeSerializer<K> keySerializer,
-                       int numberOfKeyGroups,
-                       KeyGroupRange keyGroupRange,
-                       Collection<KeyGroupsStateHandle> restoredState,
-                       TaskKvStateRegistry kvStateRegistry) throws Exception {
-               return new HeapKeyedStateBackend<>(
-                               kvStateRegistry,
-                               keySerializer,
-                               env.getUserClassLoader(),
-                               numberOfKeyGroups,
-                               keyGroupRange,
-                               restoredState);
-       }
-
-       @Override
        public String toString() {
                return "File State Backend @ " + basePath;
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index d07901b..d461dfd 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -101,28 +101,6 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                LOG.info("Initializing heap keyed state backend with stream 
factory.");
        }
 
-       public HeapKeyedStateBackend(
-                       TaskKvStateRegistry kvStateRegistry,
-                       TypeSerializer<K> keySerializer,
-                       ClassLoader userCodeClassLoader,
-                       int numberOfKeyGroups,
-                       KeyGroupRange keyGroupRange,
-                       Collection<KeyGroupsStateHandle> restoredState) throws 
Exception {
-               super(kvStateRegistry, keySerializer, userCodeClassLoader, 
numberOfKeyGroups, keyGroupRange);
-
-               LOG.info("Initializing heap keyed state backend from 
snapshot.");
-
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug("Restoring snapshot from state handles: {}.", 
restoredState);
-               }
-
-               if (MigrationUtil.isOldSavepointKeyedState(restoredState)) {
-                       restoreOldSavepointKeyedState(restoredState);
-               } else {
-                       restorePartitionedState(restoredState);
-               }
-       }
-
        // 
------------------------------------------------------------------------
        //  state backend operations
        // 
------------------------------------------------------------------------
@@ -251,6 +229,21 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                }
        }
 
+       @Override
+       public void restore(Collection<KeyGroupsStateHandle> restoredState) 
throws Exception {
+               LOG.info("Initializing heap keyed state backend from 
snapshot.");
+
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug("Restoring snapshot from state handles: {}.", 
restoredState);
+               }
+
+               if (MigrationUtil.isOldSavepointKeyedState(restoredState)) {
+                       restoreOldSavepointKeyedState(restoredState);
+               } else {
+                       restorePartitionedState(restoredState);
+               }
+       }
+
        private <N, S> void writeStateTableForKeyGroup(
                        DataOutputView outView,
                        StateTable<K, N, S> stateTable,

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index 33f03ad..58a86df 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -26,11 +26,9 @@ import 
org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 
 import java.io.IOException;
-import java.util.Collection;
 
 /**
  * A {@link AbstractStateBackend} that stores all its data and checkpoints in 
memory and has no
@@ -92,24 +90,4 @@ public class MemoryStateBackend extends AbstractStateBackend 
{
                                numberOfKeyGroups,
                                keyGroupRange);
        }
-
-       @Override
-       public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-                       Environment env, JobID jobID,
-                       String operatorIdentifier,
-                       TypeSerializer<K> keySerializer,
-                       int numberOfKeyGroups,
-                       KeyGroupRange keyGroupRange,
-                       Collection<KeyGroupsStateHandle> restoredState,
-                       TaskKvStateRegistry kvStateRegistry) throws Exception {
-
-               return new HeapKeyedStateBackend<>(
-                               kvStateRegistry,
-                               keySerializer,
-                               env.getUserClassLoader(),
-                               numberOfKeyGroups,
-                               keyGroupRange,
-                               restoredState);
-       }
-
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index 648d762..515011f 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -45,7 +45,9 @@ public class OperatorStateBackendTest {
        }
 
        private OperatorStateBackend createNewOperatorStateBackend() throws 
Exception {
-               return 
abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), 
"test-operator");
+               return abstractStateBackend.createOperatorStateBackend(
+                               createMockEnvironment(),
+                               "test-operator");
        }
 
        @Test
@@ -131,8 +133,11 @@ public class OperatorStateBackendTest {
 
                        operatorStateBackend.dispose();
 
-                       operatorStateBackend = 
abstractStateBackend.restoreOperatorStateBackend(
-                                       createMockEnvironment(), 
"testOperator", Collections.singletonList(stateHandle));
+                       operatorStateBackend = 
abstractStateBackend.createOperatorStateBackend(
+                                       createMockEnvironment(),
+                                       "testOperator");
+
+                       
operatorStateBackend.restore(Collections.singletonList(stateHandle));
 
                        assertEquals(2, 
operatorStateBackend.getRegisteredStateNames().size());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 5655f1c..9bc4c53 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -58,7 +58,13 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.RunnableFuture;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
@@ -101,8 +107,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                                keySerializer,
                                numberOfKeyGroups,
                                keyGroupRange,
-                               env.getTaskKvStateRegistry())
-;
+                               env.getTaskKvStateRegistry());
        }
 
        protected <K> AbstractKeyedStateBackend<K> 
restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle 
state) throws Exception {
@@ -127,15 +132,21 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        KeyGroupRange keyGroupRange,
                        List<KeyGroupsStateHandle> state,
                        Environment env) throws Exception {
-               return getStateBackend().restoreKeyedStateBackend(
+
+               AbstractKeyedStateBackend<K> backend = 
getStateBackend().createKeyedStateBackend(
                                env,
                                new JobID(),
                                "test_op",
                                keySerializer,
                                numberOfKeyGroups,
                                keyGroupRange,
-                               state,
                                env.getTaskKvStateRegistry());
+
+               if (null != state) {
+                       backend.restore(state);
+               }
+
+               return backend;
        }
 
        @Test

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index bd9215a..3bbc53b 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -741,13 +741,17 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>>
                Environment env = getEnvironment();
                String opId = createOperatorIdentifier(op, 
getConfiguration().getVertexID());
 
-               OperatorStateBackend newBackend = restoreStateHandles == null ?
-                               stateBackend.createOperatorStateBackend(env, 
opId)
-                               : stateBackend.restoreOperatorStateBackend(env, 
opId, restoreStateHandles);
+               OperatorStateBackend operatorStateBackend = 
stateBackend.createOperatorStateBackend(env, opId);
 
-               cancelables.registerClosable(newBackend);
+               // let operator state backend participate in the operator 
lifecycle, i.e. make it responsive to cancelation
+               cancelables.registerClosable(operatorStateBackend);
 
-               return newBackend;
+               // restore if we have some old state
+               if (null != restoreStateHandles) {
+                       operatorStateBackend.restore(restoreStateHandles);
+               }
+
+               return operatorStateBackend;
        }
 
        public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
@@ -763,29 +767,23 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>>
                                headOperator,
                                configuration.getVertexID());
 
-               if (null != restoreStateHandles && null != 
restoreStateHandles.getManagedKeyedState()) {
-                       keyedStateBackend = 
stateBackend.restoreKeyedStateBackend(
-                                       getEnvironment(),
-                                       getEnvironment().getJobID(),
-                                       operatorIdentifier,
-                                       keySerializer,
-                                       numberOfKeyGroups,
-                                       keyGroupRange,
-                                       
restoreStateHandles.getManagedKeyedState(),
-                                       
getEnvironment().getTaskKvStateRegistry());
-               } else {
-                       keyedStateBackend = 
stateBackend.createKeyedStateBackend(
-                                       getEnvironment(),
-                                       getEnvironment().getJobID(),
-                                       operatorIdentifier,
-                                       keySerializer,
-                                       numberOfKeyGroups,
-                                       keyGroupRange,
-                                       
getEnvironment().getTaskKvStateRegistry());
-               }
+               keyedStateBackend = stateBackend.createKeyedStateBackend(
+                               getEnvironment(),
+                               getEnvironment().getJobID(),
+                               operatorIdentifier,
+                               keySerializer,
+                               numberOfKeyGroups,
+                               keyGroupRange,
+                               getEnvironment().getTaskKvStateRegistry());
 
+               // let keyed state backend participate in the operator 
lifecycle, i.e. make it responsive to cancelation
                cancelables.registerClosable(keyedStateBackend);
 
+               // restore if we have some old state
+               if (null != restoreStateHandles && null != 
restoreStateHandles.getManagedKeyedState()) {
+                       
keyedStateBackend.restore(restoreStateHandles.getManagedKeyedState());
+               }
+
                @SuppressWarnings("unchecked")
                AbstractKeyedStateBackend<K> typedBackend = 
(AbstractKeyedStateBackend<K>) keyedStateBackend;
                return typedBackend;

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
index 291fd5f..5d2b106 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
@@ -51,7 +51,6 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import 
org.apache.flink.runtime.state.CheckpointStreamFactory.CheckpointStateOutputStream;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -63,12 +62,10 @@ import org.apache.flink.runtime.util.EnvironmentInformation;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StreamFilter;
 import org.apache.flink.util.SerializedValue;
-
 import org.junit.Test;
 
 import java.io.IOException;
 import java.net.URL;
-import java.util.Collection;
 import java.util.Collections;
 
 import static org.junit.Assert.assertEquals;
@@ -182,16 +179,6 @@ public class BlockingCheckpointsTest {
 
                        throw new UnsupportedOperationException();
                }
-
-               @Override
-               public <K> AbstractKeyedStateBackend<K> 
restoreKeyedStateBackend(
-                               Environment env, JobID jobID, String 
operatorIdentifier,
-                               TypeSerializer<K> keySerializer, int 
numberOfKeyGroups,
-                               KeyGroupRange keyGroupRange, 
Collection<KeyGroupsStateHandle> restoredState,
-                               TaskKvStateRegistry kvStateRegistry) throws 
Exception {
-
-                       throw new UnsupportedOperationException();
-               }
        }
 
        private static final class LockingOutputStreamFactory implements 
CheckpointStreamFactory {

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 1207cbb..fc5f65a 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.OneShotLatch;
@@ -43,8 +44,14 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import 
org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
@@ -54,6 +61,7 @@ import 
org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.runtime.util.EnvironmentInformation;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StreamSource;
@@ -66,7 +74,9 @@ import java.io.Serializable;
 import java.net.URL;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Executor;
 
 import static org.junit.Assert.assertEquals;
@@ -87,17 +97,61 @@ public class InterruptSensitiveRestoreTest {
 
        private static final OneShotLatch IN_RESTORE_LATCH = new OneShotLatch();
 
+       private static final int OPERATOR_MANAGED = 0;
+       private static final int OPERATOR_RAW = 1;
+       private static final int KEYED_MANAGED = 2;
+       private static final int KEYED_RAW = 3;
+       private static final int LEGACY = 4;
+
+       @Test
+       public void testRestoreWithInterruptLegacy() throws Exception {
+               testRestoreWithInterrupt(LEGACY);
+       }
+
+       @Test
+       public void testRestoreWithInterruptOperatorManaged() throws Exception {
+               testRestoreWithInterrupt(OPERATOR_MANAGED);
+       }
+
        @Test
-       public void testRestoreWithInterrupt() throws Exception {
+       public void testRestoreWithInterruptOperatorRaw() throws Exception {
+               testRestoreWithInterrupt(OPERATOR_RAW);
+       }
 
+       @Test
+       public void testRestoreWithInterruptKeyedManaged() throws Exception {
+               testRestoreWithInterrupt(KEYED_MANAGED);
+       }
+
+       @Test
+       public void testRestoreWithInterruptKeyedRaw() throws Exception {
+               testRestoreWithInterrupt(KEYED_RAW);
+       }
+
+       private void testRestoreWithInterrupt(int mode) throws Exception {
+
+               IN_RESTORE_LATCH.reset();
                Configuration taskConfig = new Configuration();
                StreamConfig cfg = new StreamConfig(taskConfig);
                cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
-               cfg.setStreamOperator(new StreamSource<>(new TestSource()));
+               switch (mode) {
+                       case OPERATOR_MANAGED:
+                       case OPERATOR_RAW:
+                       case KEYED_MANAGED:
+                       case KEYED_RAW:
+                               
cfg.setStateKeySerializer(IntSerializer.INSTANCE);
+                               cfg.setStreamOperator(new StreamSource<>(new 
TestSource()));
+                               break;
+                       case LEGACY:
+                               cfg.setStreamOperator(new StreamSource<>(new 
TestSourceLegacy()));
+                               break;
+                       default:
+                               throw new IllegalArgumentException();
+               }
 
                StreamStateHandle lockingHandle = new 
InterruptLockingStateHandle();
 
-               Task task = createTask(taskConfig, lockingHandle);
+               Task task = createTask(taskConfig, lockingHandle, mode);
 
                // start the task and wait until it is in "restore"
                task.startTaskThread();
@@ -122,18 +176,51 @@ public class InterruptSensitiveRestoreTest {
 
        private static Task createTask(
                        Configuration taskConfig,
-                       StreamStateHandle state) throws IOException {
+                       StreamStateHandle state,
+                       int mode) throws IOException {
 
                NetworkEnvironment networkEnvironment = 
mock(NetworkEnvironment.class);
                
when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), 
any(JobVertexID.class)))
                                .thenReturn(mock(TaskKvStateRegistry.class));
 
-               ChainedStateHandle<StreamStateHandle> operatorState = new 
ChainedStateHandle<>(Collections.singletonList(state));
+
+               ChainedStateHandle<StreamStateHandle> operatorState = null;
                List<KeyGroupsStateHandle> keyGroupStateFromBackend = 
Collections.emptyList();
                List<KeyGroupsStateHandle> keyGroupStateFromStream = 
Collections.emptyList();
                List<Collection<OperatorStateHandle>> operatorStateBackend = 
Collections.emptyList();
                List<Collection<OperatorStateHandle>> operatorStateStream = 
Collections.emptyList();
 
+               Map<String, long[]> operatorStateMetadata = new HashMap<>(1);
+               
operatorStateMetadata.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME,
 new long[]{0});
+
+               KeyGroupRangeOffsets keyGroupRangeOffsets = new 
KeyGroupRangeOffsets(new KeyGroupRange(0,0));
+
+               Collection<OperatorStateHandle> operatorStateHandles =
+                               Collections.singletonList(new 
OperatorStateHandle(operatorStateMetadata, state));
+
+               List<KeyGroupsStateHandle> keyGroupsStateHandles =
+                               Collections.singletonList(new 
KeyGroupsStateHandle(keyGroupRangeOffsets, state));
+
+               switch (mode) {
+                       case OPERATOR_MANAGED:
+                               operatorStateBackend = 
Collections.singletonList(operatorStateHandles);
+                               break;
+                       case OPERATOR_RAW:
+                               operatorStateStream = 
Collections.singletonList(operatorStateHandles);
+                               break;
+                       case KEYED_MANAGED:
+                               keyGroupStateFromBackend = 
keyGroupsStateHandles;
+                               break;
+                       case KEYED_RAW:
+                               keyGroupStateFromStream = keyGroupsStateHandles;
+                               break;
+                       case LEGACY:
+                               operatorState = new 
ChainedStateHandle<>(Collections.singletonList(state));
+                               break;
+                       default:
+                               throw new IllegalArgumentException();
+               }
+
                TaskStateHandles taskStateHandles = new TaskStateHandles(
                        operatorState,
                        operatorStateBackend,
@@ -256,7 +343,7 @@ public class InterruptSensitiveRestoreTest {
 
        // 
------------------------------------------------------------------------
 
-       private static class TestSource implements SourceFunction<Object>, 
Checkpointed<Serializable> {
+       private static class TestSourceLegacy implements 
SourceFunction<Object>, Checkpointed<Serializable> {
                private static final long serialVersionUID = 1L;
 
                @Override
@@ -278,4 +365,27 @@ public class InterruptSensitiveRestoreTest {
                        fail("should never be called");
                }
        }
+
+       private static class TestSource implements SourceFunction<Object>, 
CheckpointedFunction {
+               private static final long serialVersionUID = 1L;
+
+               @Override
+               public void run(SourceContext<Object> ctx) throws Exception {
+                       fail("should never be called");
+               }
+
+               @Override
+               public void cancel() {}
+
+
+               @Override
+               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
+                       fail("should never be called");
+               }
+
+               @Override
+               public void initializeState(FunctionInitializationContext 
context) throws Exception {
+                       
((StateInitializationContext)context).getRawOperatorStateInputs().iterator().next().getStream().read();
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
index 346d5c3..7fe4ebc 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
@@ -192,12 +192,17 @@ public class AbstractStreamOperatorTestHarness<OUT> {
                                        final StreamOperator<?> operator = 
(StreamOperator<?>) invocationOnMock.getArguments()[0];
                                        final Collection<OperatorStateHandle> 
stateHandles = (Collection<OperatorStateHandle>) 
invocationOnMock.getArguments()[1];
                                        OperatorStateBackend osb;
-                                       if (null == stateHandles) {
-                                               osb = 
stateBackend.createOperatorStateBackend(environment, 
operator.getClass().getSimpleName());
-                                       } else {
-                                               osb = 
stateBackend.restoreOperatorStateBackend(environment, 
operator.getClass().getSimpleName(), stateHandles);
-                                       }
+
+                                       osb = 
stateBackend.createOperatorStateBackend(
+                                                       environment,
+                                                       
operator.getClass().getSimpleName());
+
                                        
mockTask.getCancelables().registerClosable(osb);
+
+                                       if (null != stateHandles) {
+                                               osb.restore(stateHandles);
+                                       }
+
                                        return osb;
                                }
                        
}).when(mockTask).createOperatorStateBackend(any(StreamOperator.class), 
any(Collection.class));

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 3a47a1d..4abb6e2 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -100,33 +100,24 @@ public class KeyedOneInputStreamOperatorTestHarness<K, 
IN, OUT>
                                        final int numberOfKeyGroups = (Integer) 
invocationOnMock.getArguments()[1];
                                        final KeyGroupRange keyGroupRange = 
(KeyGroupRange) invocationOnMock.getArguments()[2];
 
-                                       if(keyedStateBackend != null) {
+                                       if (keyedStateBackend != null) {
                                                keyedStateBackend.dispose();
                                        }
 
-                                       if (restoredKeyedState == null) {
-                                               keyedStateBackend = 
stateBackend.createKeyedStateBackend(
-                                                               
mockTask.getEnvironment(),
-                                                               new JobID(),
-                                                               "test_op",
-                                                               keySerializer,
-                                                               
numberOfKeyGroups,
-                                                               keyGroupRange,
-                                                               
mockTask.getEnvironment().getTaskKvStateRegistry());
-                                               return keyedStateBackend;
-                                       } else {
-                                               keyedStateBackend = 
stateBackend.restoreKeyedStateBackend(
-                                                               
mockTask.getEnvironment(),
-                                                               new JobID(),
-                                                               "test_op",
-                                                               keySerializer,
-                                                               
numberOfKeyGroups,
-                                                               keyGroupRange,
-                                                               
restoredKeyedState,
-                                                               
mockTask.getEnvironment().getTaskKvStateRegistry());
-                                               restoredKeyedState = null;
-                                               return keyedStateBackend;
+                                       keyedStateBackend = 
stateBackend.createKeyedStateBackend(
+                                                       
mockTask.getEnvironment(),
+                                                       new JobID(),
+                                                       "test_op",
+                                                       keySerializer,
+                                                       numberOfKeyGroups,
+                                                       keyGroupRange,
+                                                       
mockTask.getEnvironment().getTaskKvStateRegistry());
+
+                                       if (restoredKeyedState != null) {
+                                               
keyedStateBackend.restore(restoredKeyedState);
                                        }
+
+                                       return keyedStateBackend;
                                }
                        
}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), 
any(KeyGroupRange.class));
                } catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
index 0aa91d9..8e76f70 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
@@ -94,29 +94,18 @@ public class KeyedTwoInputStreamOperatorTestHarness<K, IN1, 
IN2, OUT>
                                                keyedStateBackend.close();
                                        }
 
-                                       if (restoredKeyedState == null) {
-                                               keyedStateBackend = 
stateBackend.createKeyedStateBackend(
-                                                               
mockTask.getEnvironment(),
-                                                               new JobID(),
-                                                               "test_op",
-                                                               keySerializer,
-                                                               
numberOfKeyGroups,
-                                                               keyGroupRange,
-                                                               
mockTask.getEnvironment().getTaskKvStateRegistry());
-                                               return keyedStateBackend;
-                                       } else {
-                                               keyedStateBackend = 
stateBackend.restoreKeyedStateBackend(
-                                                               
mockTask.getEnvironment(),
-                                                               new JobID(),
-                                                               "test_op",
-                                                               keySerializer,
-                                                               
numberOfKeyGroups,
-                                                               keyGroupRange,
-                                                               
restoredKeyedState,
-                                                               
mockTask.getEnvironment().getTaskKvStateRegistry());
-                                               restoredKeyedState = null;
-                                               return keyedStateBackend;
+                                       keyedStateBackend = 
stateBackend.createKeyedStateBackend(
+                                                       
mockTask.getEnvironment(),
+                                                       new JobID(),
+                                                       "test_op",
+                                                       keySerializer,
+                                                       numberOfKeyGroups,
+                                                       keyGroupRange,
+                                                       
mockTask.getEnvironment().getTaskKvStateRegistry());
+                                       if (restoredKeyedState != null) {
+                                               
keyedStateBackend.restore(restoredKeyedState);
                                        }
+                                       return keyedStateBackend;
                                }
                        
}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), 
any(KeyGroupRange.class));
                } catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/39fc07f8/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
index 963d18a..0e62fbb 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
@@ -32,13 +32,11 @@ import 
org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
 import org.junit.Test;
 
 import java.io.IOException;
-import java.util.Collection;
 
 import static org.junit.Assert.fail;
 
@@ -110,19 +108,6 @@ public class StateBackendITCase extends 
StreamingMultipleProgramsTestBase {
                                TaskKvStateRegistry kvStateRegistry) throws 
Exception {
                        throw new SuccessException();
                }
-
-               @Override
-               public <K> AbstractKeyedStateBackend<K> 
restoreKeyedStateBackend(
-                               Environment env,
-                               JobID jobID,
-                               String operatorIdentifier,
-                               TypeSerializer<K> keySerializer,
-                               int numberOfKeyGroups,
-                               KeyGroupRange keyGroupRange,
-                               Collection<KeyGroupsStateHandle> restoredState,
-                               TaskKvStateRegistry kvStateRegistry) throws 
Exception {
-                       throw new SuccessException();
-               }
        }
 
        static final class SuccessException extends IOException {

Reply via email to