Repository: flink
Updated Branches:
  refs/heads/master 41f581822 -> e9f660d1f


[FLINK-3466] [runtime] Make state handles cancelable.

State handles are cancelable, to make sure long-running checkpoint restore 
operations finish early on cancellation, even if the code does not properly 
react to interrupts.

This is especially important since the HDFS client code is buggy enough that it 
deadlocks when interrupted without being closed.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e9f660d1
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e9f660d1
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e9f660d1

Branch: refs/heads/master
Commit: e9f660d1ff5540c7ef829f2de5bb870b787c18b7
Parents: 2837c60
Author: Stephan Ewen <[email protected]>
Authored: Wed Jul 13 21:32:40 2016 +0200
Committer: Stephan Ewen <[email protected]>
Committed: Fri Jul 15 17:18:04 2016 +0200

----------------------------------------------------------------------
 .../streaming/state/RocksDBStateBackend.java    |  17 +-
 .../org/apache/flink/util/ExceptionUtils.java   |  26 ++-
 .../runtime/state/AbstractCloseableHandle.java  | 127 +++++++++++
 .../runtime/state/AbstractStateBackend.java     |   5 +
 .../state/AsynchronousKvStateSnapshot.java      |   7 +
 .../runtime/state/GenericFoldingState.java      |   7 +
 .../flink/runtime/state/GenericListState.java   |   6 +
 .../runtime/state/GenericReducingState.java     |   7 +
 .../flink/runtime/state/KvStateSnapshot.java    |  21 +-
 .../flink/runtime/state/LocalStateHandle.java   |   4 +
 .../apache/flink/runtime/state/StateHandle.java |  22 +-
 .../apache/flink/runtime/state/StateObject.java |  54 +++++
 .../filesystem/AbstractFileStateHandle.java     |  13 +-
 .../filesystem/AbstractFsStateSnapshot.java     |  11 +-
 .../filesystem/FileSerializableStateHandle.java |   7 +-
 .../state/filesystem/FileStreamStateHandle.java |  24 +-
 .../state/memory/AbstractMemStateSnapshot.java  |  23 +-
 .../state/memory/ByteStreamStateHandle.java     |  27 ++-
 .../state/memory/SerializedStateHandle.java     |   4 +-
 .../messages/CheckpointMessagesTest.java        |   6 +-
 .../state/AbstractCloseableHandleTest.java      |  89 ++++++++
 .../ZooKeeperStateHandleStoreITCase.java        |   4 +
 .../operators/GenericWriteAheadSink.java        |  25 +++
 .../streaming/runtime/tasks/StreamTask.java     | 186 ++++++++++------
 .../runtime/tasks/StreamTaskState.java          |  62 +++++-
 .../runtime/tasks/StreamTaskStateList.java      |  27 ++-
 .../tasks/InterruptSensitiveRestoreTest.java    | 223 +++++++++++++++++++
 .../tasks/StreamTaskAsyncCheckpointTest.java    |  10 +-
 28 files changed, 912 insertions(+), 132 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
 
b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index 4778aa0..9496d12 100644
--- 
a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ 
b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -32,6 +32,7 @@ import java.util.Random;
 import java.util.UUID;
 
 import org.apache.commons.io.FileUtils;
+
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.state.FoldingState;
 import org.apache.flink.api.common.state.FoldingStateDescriptor;
@@ -56,12 +57,13 @@ import org.apache.flink.runtime.state.KvState;
 import org.apache.flink.runtime.state.KvStateSnapshot;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.api.common.state.StateBackend;
-
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.util.SerializableObject;
 import org.apache.flink.streaming.util.HDFSCopyFromLocal;
 import org.apache.flink.streaming.util.HDFSCopyToLocal;
+
 import org.apache.hadoop.fs.FileSystem;
+
 import org.rocksdb.BackupEngine;
 import org.rocksdb.BackupableDBOptions;
 import org.rocksdb.ColumnFamilyDescriptor;
@@ -74,6 +76,7 @@ import org.rocksdb.RestoreOptions;
 import org.rocksdb.RocksDB;
 import org.rocksdb.RocksDBException;
 import org.rocksdb.RocksIterator;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -684,6 +687,11 @@ public class RocksDBStateBackend extends 
AbstractStateBackend {
                        FileSystem fs = FileSystem.get(backupUri, 
HadoopFileSystem.getHadoopConfiguration());
                        return fs.getContentSummary(new 
org.apache.hadoop.fs.Path(backupUri)).getLength();
                }
+
+               @Override
+               public void close() throws IOException {
+                       // cannot do much here: nothing is closed, so closing this handle does not
+                       // abort an in-progress restore. NOTE(review): presumably the backup is copied
+                       // (e.g. via HDFSCopyToLocal) outside of any stream held by this handle — confirm.
+               }
        }
 
        // 
------------------------------------------------------------------------
@@ -797,7 +805,7 @@ public class RocksDBStateBackend extends 
AbstractStateBackend {
                 * Creates a new snapshot from the given state parameters.
                 */
                private FinalFullyAsyncSnapshot(StateHandle<DataInputView> 
stateHandle, long checkpointId) {
-                       this.stateHandle = stateHandle;
+                       // fail fast: getStateSize() and close() dereference the handle unconditionally
+                       this.stateHandle = requireNonNull(stateHandle);
                        this.checkpointId = checkpointId;
                }
 
@@ -818,6 +826,11 @@ public class RocksDBStateBackend extends 
AbstractStateBackend {
                public final long getStateSize() throws Exception {
                        return stateHandle.getStateSize();
                }
+
+               @Override
+               public void close() throws IOException {
+                       // forward to the wrapped handle so that canceling an in-progress restore reaches it
+                       stateHandle.close();
+               }
        }
 
        // 
------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java 
b/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java
index b3c78f8..7227006 100644
--- a/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java
+++ b/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java
@@ -103,7 +103,7 @@ public final class ExceptionUtils {
         * (and RuntimeException and Error). Throws this exception directly, if 
it is an IOException,
         * a RuntimeException, or an Error. Otherwise does nothing.
         *
-        * @param t The throwable to be thrown.
+        * @param t The Throwable to be thrown.
         */
        public static void tryRethrowIOException(Throwable t) throws 
IOException {
                if (t instanceof IOException) {
@@ -118,6 +118,30 @@ public final class ExceptionUtils {
        }
 
        /**
+        * Re-throws the given {@code Throwable} in scenarios where the signature allows only
+        * IOExceptions (and RuntimeException and Error).
+        *
+        * <p>Throws the Throwable directly if it is an IOException, a RuntimeException, or an Error.
+        * Otherwise, wraps it (as the cause) in a new IOException and throws that.
+        *
+        * <p>Unlike {@link #tryRethrowIOException(Throwable)}, which does nothing for other
+        * Throwables, this method always throws.
+        *
+        * @param t The Throwable to be thrown.
+        * @throws IOException Thrown if {@code t} is an IOException, or wraps {@code t} if it is
+        *                     any other checked Throwable.
+        */
+       public static void rethrowIOException(Throwable t) throws IOException {
+               if (t instanceof IOException) {
+                       throw (IOException) t;
+               }
+               else if (t instanceof RuntimeException) {
+                       throw (RuntimeException) t;
+               }
+               else if (t instanceof Error) {
+                       throw (Error) t;
+               }
+               else {
+                       throw new IOException(t);
+               }
+       }
+
+       /**
         * Private constructor to prevent instantiation.
         */
        private ExceptionUtils() {

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractCloseableHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractCloseableHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractCloseableHandle.java
new file mode 100644
index 0000000..609158d
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractCloseableHandle.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
+
+/**
+ * A simple base class for closeable state handles.
+ *
+ * <p>Subclasses can register a stream (or any other {@link Closeable}) via
+ * {@link #registerCloseable(Closeable)}. Closing the handle closes the registered object, and
+ * registering an object on a handle that is already closed closes that object immediately.
+ * This lets a canceling thread abort an in-progress read of the state, even if the reading
+ * code does not react to interrupts.
+ *
+ * <p>Both the registered closeable and the closed flag are {@code transient}: a deserialized
+ * handle starts out open and with nothing registered.
+ */
+public abstract class AbstractCloseableHandle implements Closeable, 
Serializable {
+
+       /** Serial Version UID must be constant to maintain format compatibility. */
+       private static final long serialVersionUID = 1L;
+
+       /**
+        * Atomically updates the "isClosed" field, without needing an extra member object
+        * such as an {@code AtomicBoolean}.
+        */
+       private static final AtomicIntegerFieldUpdater<AbstractCloseableHandle> 
CLOSER = 
+                       
AtomicIntegerFieldUpdater.newUpdater(AbstractCloseableHandle.class, "isClosed");
+
+       // ------------------------------------------------------------------------
+
+       /** The most recently registered closeable; it is closed when this handle is closed. */
+       private transient volatile Closeable toClose;
+
+       /** Flag to remember if this handle was already closed (0 = open, 1 = closed). */
+       @SuppressWarnings("unused") // this field is actually updated, but via the "CLOSER" updater
+       private transient volatile int isClosed;
+
+       // ------------------------------------------------------------------------
+
+       /**
+        * Registers a closeable (for example the stream that an in-progress restore reads from)
+        * that is closed when this handle is closed. Only one closeable is retained: a new
+        * registration replaces the previous one.
+        *
+        * <p>If this handle has already been closed, the given closeable is closed immediately
+        * and an exception is thrown, so that the caller does not proceed with the access.
+        *
+        * @param toClose The closeable to register. {@code null} is ignored.
+        * @throws IOException Thrown, if the handle has already been closed, or if closing the
+        *                     given closeable fails.
+        */
+       protected final void registerCloseable(Closeable toClose) throws 
IOException {
+               if (toClose == null) {
+                       return;
+               }
+
+               // NOTE: The order of operations matters here:
+               // (1) first setting the closeable
+               // (2) checking the flag.
+               // Because the order in the close() method is the opposite, and
+               // both variables are volatile (reordering barriers), we can be sure that
+               // one of the methods always notices the effect of a concurrent call to the
+               // other method. If both notice it, the closeable is closed twice, which
+               // java.io.Closeable permits (closing an already closed stream has no effect).
+
+               this.toClose = toClose;
+
+               // the handle was closed before (or concurrently with) this registration
+               if (this.isClosed != 0) {
+                       toClose.close();
+                       throw new IOException("handle is closed");
+               }
+       }
+
+       /**
+        * Closes the handle.
+        *
+        * <p>If a {@code Closeable} has been registered via {@link #registerCloseable(Closeable)},
+        * then it is closed as well.
+        *
+        * <p>Any {@code Closeable} registered via {@link #registerCloseable(Closeable)} after this
+        * call is closed immediately, and that registration throws an exception.
+        *
+        * <p>Only the first call has an effect; subsequent calls return without doing anything.
+        *
+        * @throws IOException Exceptions occurring while closing an already registered
+        *                     {@code Closeable} are forwarded.
+        *
+        * @see #registerCloseable(Closeable)
+        */
+       @Override
+       public final void close() throws IOException {
+               // NOTE: The order of operations matters here:
+               // (1) first setting the closed flag
+               // (2) checking whether there is already a closeable
+               // Because the order in the registerCloseable(Closeable) method is the opposite, and
+               // both variables are volatile (reordering barriers), we can be sure that
+               // one of the methods always notices the effect of a concurrent call to the
+               // other method.
+
+               // only the call that flips the flag from 0 to 1 closes the registered closeable
+               if (CLOSER.compareAndSet(this, 0, 1)) {
+                       final Closeable toClose = this.toClose;
+                       if (toClose != null) {
+                               this.toClose = null;
+                               toClose.close();
+                       }
+               }
+       }
+
+       /**
+        * Checks whether this handle has been closed.
+        *
+        * @return True if the handle is closed, false otherwise.
+        */
+       public boolean isClosed() {
+               return isClosed != 0;
+       }
+
+       /**
+        * Checks whether the handle is closed and throws an exception if it is.
+        * If the handle is not closed, this method does nothing.
+        *
+        * @throws IOException Thrown, if the handle has been closed.
+        */
+       public void ensureNotClosed() throws IOException {
+               if (isClosed != 0) {
+                       throw new IOException("handle is closed");
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
index 6ab4999..b86688b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
@@ -480,5 +480,10 @@ public abstract class AbstractStateBackend implements 
java.io.Serializable {
                public long getStateSize() throws Exception {
                        return stream.getStateSize();
                }
+
+               @Override
+               public void close() throws IOException {
+                       // delegate to the wrapped stream handle so that a cancellation reaches it
+                       stream.close();
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsynchronousKvStateSnapshot.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsynchronousKvStateSnapshot.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsynchronousKvStateSnapshot.java
index 877034d..c2fc8a4 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsynchronousKvStateSnapshot.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsynchronousKvStateSnapshot.java
@@ -22,6 +22,8 @@ import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 
+import java.io.IOException;
+
 /**
  * {@link KvStateSnapshot} that asynchronously materializes the state that it 
represents. Instead
  * of representing a materialized handle to state this would normally hold the 
(immutable) state
@@ -58,4 +60,9 @@ public abstract class AsynchronousKvStateSnapshot<K, N, S 
extends State, SD exte
        public long getStateSize() throws Exception {
                throw new RuntimeException("This should never be called and 
probably points to a bug.");
        }
+
+       // Like getStateSize(), close() is not expected on a snapshot that has not been materialized.
+       // NOTE(review): if task cancellation closes every state object it holds, reaching this
+       // would turn a cancellation into a RuntimeException; confirm such snapshots are never closed.
+       @Override
+       public void close() throws IOException {
+               throw new RuntimeException("This should never be called and 
probably points to a bug.");
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java
index 762cc3a..5f6600d 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java
@@ -24,6 +24,8 @@ import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 
+import java.io.IOException;
+
 /**
  * Generic implementation of {@link FoldingState} based on a wrapped {@link 
ValueState}.
  *
@@ -128,5 +130,10 @@ public class GenericFoldingState<K, N, T, ACC, Backend 
extends AbstractStateBack
                public long getStateSize() throws Exception {
                        return wrappedSnapshot.getStateSize();
                }
+
+               @Override
+               public void close() throws IOException {
+                       // delegate to the wrapped snapshot, mirroring getStateSize()
+                       wrappedSnapshot.close();
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java
index 9393082..3414855 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 
@@ -133,5 +134,10 @@ public class GenericListState<K, N, T, Backend extends 
AbstractStateBackend, W e
                public long getStateSize() throws Exception {
                        return wrappedSnapshot.getStateSize();
                }
+
+               @Override
+               public void close() throws IOException {
+                       // delegate to the wrapped snapshot, mirroring getStateSize()
+                       wrappedSnapshot.close();
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java
index 7407dfa..9a2eb21 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java
@@ -24,6 +24,8 @@ import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 
+import java.io.IOException;
+
 /**
  * Generic implementation of {@link ReducingState} based on a wrapped {@link 
ValueState}.
  *
@@ -131,5 +133,10 @@ public class GenericReducingState<K, N, T, Backend extends 
AbstractStateBackend,
                public long getStateSize() throws Exception {
                        return wrappedSnapshot.getStateSize();
                }
+
+               @Override
+               public void close() throws IOException {
+                       // delegate to the wrapped snapshot, mirroring getStateSize()
+                       wrappedSnapshot.close();
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java
index 847d53e..5654845 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java
@@ -39,7 +39,8 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
  * @param <SD> The type of the {@link StateDescriptor}
  * @param <Backend> The type of the backend that can restore the state from 
this snapshot.
  */
-public interface KvStateSnapshot<K, N, S extends State, SD extends 
StateDescriptor<S, ?>, Backend extends AbstractStateBackend> extends 
java.io.Serializable {
+public interface KvStateSnapshot<K, N, S extends State, SD extends 
StateDescriptor<S, ?>, Backend extends AbstractStateBackend> 
+               extends StateObject {
 
        /**
         * Loads the key/value state back from this snapshot.
@@ -57,22 +58,4 @@ public interface KvStateSnapshot<K, N, S extends State, SD 
extends StateDescript
                Backend stateBackend,
                TypeSerializer<K> keySerializer,
                ClassLoader classLoader) throws Exception;
-
-       /**
-        * Discards the state snapshot, removing any resources occupied by it.
-        * 
-        * @throws Exception Exceptions occurring during the state disposal 
should be forwarded.
-        */
-       void discardState() throws Exception;
-
-       /**
-        * Returns the size of the state in bytes.
-        *
-        * <p>If the the size is not known, return <code>0</code>.
-        *
-        * @return Size of the state in bytes.
-        *
-        * @throws Exception If the operation fails during size retrieval.
-        */
-       long getStateSize() throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/LocalStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/LocalStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/LocalStateHandle.java
index 4e60ab6..4e7531f 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/LocalStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/LocalStateHandle.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
+import java.io.IOException;
 import java.io.Serializable;
 
 /**
@@ -46,4 +47,7 @@ public class LocalStateHandle<T extends Serializable> 
implements StateHandle<T>
        public long getStateSize() {
                return 0;
        }
+
+       // Nothing to release. NOTE(review): presumably the state object is held directly by this
+       // handle (no stream is opened to read it), so there is no access to cancel — confirm.
+       @Override
+       public void close() throws IOException {}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateHandle.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateHandle.java
index 800e34d..b736252 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateHandle.java
@@ -18,15 +18,12 @@
 
 package org.apache.flink.runtime.state;
 
-
-import java.io.Serializable;
-
 /**
  * StateHandle is a general handle interface meant to abstract operator state 
fetching. 
  * A StateHandle implementation can for example include the state itself in 
cases where the state 
  * is lightweight or fetching it lazily from some external storage when the 
state is too large.
  */
-public interface StateHandle<T> extends Serializable {
+public interface StateHandle<T> extends StateObject {
 
        /**
         * This retrieves and return the state represented by the handle.
@@ -37,21 +34,4 @@ public interface StateHandle<T> extends Serializable {
         * @throws java.lang.Exception Thrown, if the state cannot be fetched.
         */
        T getState(ClassLoader userCodeClassLoader) throws Exception;
-       
-       /**
-        * Discards the state referred to by this handle, to free up resources 
in
-        * the persistent storage. This method is called when the handle will 
not be
-        * used any more.
-        */
-       void discardState() throws Exception;
-
-       /**
-        * Returns the size of the state in bytes.
-        *
-        * <p>If the the size is not known, return <code>0</code>.
-        *
-        * @return Size of the state in bytes.
-        * @throws Exception If the operation fails during size retrieval.
-        */
-       long getStateSize() throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
new file mode 100644
index 0000000..a43a2c5
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * Base of all types that represent checkpointed state. Specializations are, for
+ * example, {@link StateHandle StateHandles} (directly resolve to state) and
+ * {@link KvStateSnapshot key/value state snapshots}.
+ *
+ * <p>State objects define how to:
+ * <ul>
+ *     <li><b>Discard state</b>: The {@link #discardState()} method defines how state is permanently
+ *         disposed/deleted. After that method call, the state may not be recoverable any more.</li>
+ *
+ *     <li><b>Close the current state access</b>: The {@link #close()} method defines how to
+ *         stop the current access to, or recovery of, the state. It is called, for example, when
+ *         an operation is canceled during recovery. Because its purpose is to abort an access that
+ *         is still in progress, implementations should expect it to be called concurrently with
+ *         that access (see {@link AbstractCloseableHandle}). Closing only stops the current
+ *         access; it is separate from discarding the state.</li>
+ * </ul>
+ */
+public interface StateObject extends java.io.Closeable, java.io.Serializable {
+
+       /**
+        * Discards the state referred to by this object, to free up resources in
+        * the persistent storage. This method is called when the state object will not be
+        * used any more.
+        *
+        * @throws Exception Exceptions occurring during the state disposal should be forwarded.
+        */
+       void discardState() throws Exception;
+
+       /**
+        * Returns the size of the state in bytes.
+        *
+        * <p>If the size is not known, return {@code 0}.
+        *
+        * @return Size of the state in bytes.
+        * @throws Exception If the operation fails during size retrieval.
+        */
+       long getStateSize() throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFileStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFileStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFileStateHandle.java
index 00800b2..0585062 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFileStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFileStateHandle.java
@@ -20,18 +20,20 @@ package org.apache.flink.runtime.state.filesystem;
 
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.state.AbstractCloseableHandle;
+import org.apache.flink.runtime.state.StateObject;
 
 import java.io.IOException;
 
-import static java.util.Objects.requireNonNull;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
  * Base class for state that is stored in a file.
  */
-public abstract class AbstractFileStateHandle implements java.io.Serializable {
-       
+public abstract class AbstractFileStateHandle extends AbstractCloseableHandle 
implements StateObject {
+
        private static final long serialVersionUID = 350284443258002355L;
-       
+
        /** The path to the file in the filesystem, fully describing the file 
system */
        private final Path filePath;
 
@@ -44,7 +46,7 @@ public abstract class AbstractFileStateHandle implements 
java.io.Serializable {
         * @param filePath The path to the file that stores the state.
         */
        protected AbstractFileStateHandle(Path filePath) {
-               this.filePath = requireNonNull(filePath);
+               // reject a null path up front; discardState() deletes exactly this path
+               this.filePath = checkNotNull(filePath);
        }
 
        /**
@@ -61,6 +63,7 @@ public abstract class AbstractFileStateHandle implements 
java.io.Serializable {
         * 
         * @throws Exception Thrown, if the file deletion (not the directory 
deletion) fails.
         */
+       @Override
        public void discardState() throws Exception {
                getFileSystem().delete(filePath, false);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFsStateSnapshot.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFsStateSnapshot.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFsStateSnapshot.java
index cd02870..0692541 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFsStateSnapshot.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/AbstractFsStateSnapshot.java
@@ -27,7 +27,6 @@ import 
org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.runtime.state.KvState;
 import org.apache.flink.runtime.state.KvStateSnapshot;
 
-import java.io.DataInputStream;
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
@@ -39,7 +38,8 @@ import java.util.Map;
  * @param <N> The type of the namespace in the snapshot state.
  * @param <SV> The type of the state value.
  */
-public abstract class AbstractFsStateSnapshot<K, N, SV, S extends State, SD 
extends StateDescriptor<S, ?>> extends AbstractFileStateHandle implements 
KvStateSnapshot<K, N, S, SD, FsStateBackend> {
+public abstract class AbstractFsStateSnapshot<K, N, SV, S extends State, SD 
extends StateDescriptor<S, ?>> 
+               extends AbstractFileStateHandle implements KvStateSnapshot<K, 
N, S, SD, FsStateBackend> {
 
        private static final long serialVersionUID = 1L;
 
@@ -95,9 +95,13 @@ public abstract class AbstractFsStateSnapshot<K, N, SV, S 
extends State, SD exte
                }
 
                // state restore
+               ensureNotClosed();
+
                try (FSDataInputStream inStream = 
stateBackend.getFileSystem().open(getFilePath())) {
-                       DataInputViewStreamWrapper inView = new 
DataInputViewStreamWrapper(new DataInputStream(inStream));
+                       // register the stream with this handle, so that closing the handle (e.g. on task cancellation) closes the stream and aborts the in-progress restore
+                       registerCloseable(inStream);
 
+                       DataInputViewStreamWrapper inView = new 
DataInputViewStreamWrapper(inStream);
 
                        final int numKeys = inView.readInt();
                        HashMap<N, Map<K, SV>> stateMap = new 
HashMap<>(numKeys);
@@ -114,7 +118,6 @@ public abstract class AbstractFsStateSnapshot<K, N, SV, S 
extends State, SD exte
                                }
                        }
 
-//                     return new FsHeapValueState<>(stateBackend, 
keySerializer, namespaceSerializer, stateDesc, stateMap);
                        return createFsState(stateBackend, stateMap);
                }
                catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileSerializableStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileSerializableStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileSerializableStateHandle.java
index 662678e..34a1cb0 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileSerializableStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileSerializableStateHandle.java
@@ -35,7 +35,7 @@ import java.io.Serializable;
 public class FileSerializableStateHandle<T extends Serializable> extends 
AbstractFileStateHandle implements StateHandle<T> {
 
        private static final long serialVersionUID = -657631394290213622L;
-       
+
        /**
         * Creates a new FileSerializableStateHandle pointing to state at the 
given file path.
         * 
@@ -48,7 +48,12 @@ public class FileSerializableStateHandle<T extends 
Serializable> extends Abstrac
        @Override
        @SuppressWarnings("unchecked")
        public T getState(ClassLoader classLoader) throws Exception {
+               ensureNotClosed();
+
                try (FSDataInputStream inStream = 
getFileSystem().open(getFilePath())) {
+                       // register the stream with this handle, so that closing the handle aborts a deserialization that is blocked reading the stream
+                       registerCloseable(inStream);
+
                        ObjectInputStream ois = new 
InstantiationUtil.ClassLoaderObjectInputStream(inStream, classLoader);
                        return (T) ois.readObject();
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStreamStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStreamStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStreamStateHandle.java
index be9c4cd..5bfb4ee 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStreamStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStreamStateHandle.java
@@ -30,7 +30,7 @@ import java.io.Serializable;
  * A state handle that points to state in a file system, accessible as an 
input stream.
  */
 public class FileStreamStateHandle extends AbstractFileStateHandle implements 
StreamStateHandle {
-       
+
        private static final long serialVersionUID = -6826990484549987311L;
 
        /**
@@ -44,7 +44,13 @@ public class FileStreamStateHandle extends 
AbstractFileStateHandle implements St
 
        @Override
        public InputStream getState(ClassLoader userCodeClassLoader) throws 
Exception {
-               return getFileSystem().open(getFilePath());
+               ensureNotClosed();
+
+               InputStream inStream = getFileSystem().open(getFilePath());
+               // register the stream so that closing this handle (e.g. on cancellation) also closes the stream handed to the caller
+               registerCloseable(inStream);
+
+               return inStream; 
        }
 
        /**
@@ -60,6 +66,18 @@ public class FileStreamStateHandle extends 
AbstractFileStateHandle implements St
 
        @Override
        public <T extends Serializable> StateHandle<T> toSerializableHandle() {
-               return new FileSerializableStateHandle<>(getFilePath());
+               FileSerializableStateHandle<T> handle = new 
FileSerializableStateHandle<>(getFilePath());
+
+               // if this handle was already closed, close the new handle too, so the closed status carries over
+               if (isClosed()) {
+                       try {
+                               handle.close();
+                       } catch (IOException e) {
+                               // should not happen on a fresh handle, but forward anyway
+                               throw new RuntimeException(e);
+                       }
+               }
+
+               return handle;
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/AbstractMemStateSnapshot.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/AbstractMemStateSnapshot.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/AbstractMemStateSnapshot.java
index 86d4c7d..e1b62d2 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/AbstractMemStateSnapshot.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/AbstractMemStateSnapshot.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.state.KvState;
 import org.apache.flink.runtime.state.KvStateSnapshot;
 import org.apache.flink.runtime.util.DataInputDeserializer;
 
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -36,7 +37,8 @@ import java.util.Map;
  * @param <N> The type of the namespace in the snapshot state.
  * @param <SV> The type of the value in the snapshot state.
  */
-public abstract class AbstractMemStateSnapshot<K, N, SV, S extends State, SD 
extends StateDescriptor<S, ?>> implements KvStateSnapshot<K, N, S, SD, 
MemoryStateBackend> {
+public abstract class AbstractMemStateSnapshot<K, N, SV, S extends State, SD 
extends StateDescriptor<S, ?>> 
+               implements KvStateSnapshot<K, N, S, SD, MemoryStateBackend> {
 
        private static final long serialVersionUID = 1L;
 
@@ -54,6 +56,9 @@ public abstract class AbstractMemStateSnapshot<K, N, SV, S 
extends State, SD ext
 
        /** The serialized data of the state key/value pairs */
        private final byte[] data;
+
+       /** Set by close(), which may run on a different (canceling) thread; volatile so that the restore loop observes it */
+       private transient volatile boolean closed;
 
        /**
         * Creates a new heap memory state snapshot.
@@ -92,14 +97,18 @@ public abstract class AbstractMemStateSnapshot<K, N, SV, S 
extends State, SD ext
                                        "(" + this.keySerializer + ") " +
                                        "now is (" + keySerializer + ")");
                }
-               
+
+               if (closed) {
+                       throw new IOException("snapshot has been closed");
+               }
+
                // restore state
                DataInputDeserializer inView = new DataInputDeserializer(data, 
0, data.length);
 
                final int numKeys = inView.readInt();
                HashMap<N, Map<K, SV>> stateMap = new HashMap<>(numKeys);
 
-               for (int i = 0; i < numKeys; i++) {
+               for (int i = 0; i < numKeys && !closed; i++) {
                        N namespace = namespaceSerializer.deserialize(inView);
                        final int numValues = inView.readInt();
                        Map<K, SV> namespaceMap = new HashMap<>(numValues);
@@ -111,6 +120,11 @@ public abstract class AbstractMemStateSnapshot<K, N, SV, S 
extends State, SD ext
                        }
                }
 
+               // the loop above stops early when closed concurrently; fail instead of returning partial state
+               if (closed) {
+                       throw new IOException("snapshot has been closed");
+               }
+
                return createMemState(stateMap);
        }
 
@@ -124,4 +138,10 @@ public abstract class AbstractMemStateSnapshot<K, N, SV, S 
extends State, SD ext
        public long getStateSize() {
                return data.length;
        }
+
+       /** Marks this snapshot as closed; a running or subsequent restore then fails with an IOException. */
+       @Override
+       public void close() {
+               closed = true;
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
index 61473ea..ba6de42 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
@@ -18,17 +18,19 @@
 
 package org.apache.flink.runtime.state.memory;
 
+import org.apache.flink.runtime.state.AbstractCloseableHandle;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import java.io.ByteArrayInputStream;
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.Serializable;
 
 /**
  * A state handle that contains stream state in a byte array.
  */
-public final class ByteStreamStateHandle implements StreamStateHandle {
+public final class ByteStreamStateHandle extends AbstractCloseableHandle 
implements StreamStateHandle {
 
        private static final long serialVersionUID = -5280226231200217594L;
        
@@ -45,8 +47,13 @@ public final class ByteStreamStateHandle implements 
StreamStateHandle {
        }
 
        @Override
-       public InputStream getState(ClassLoader userCodeClassLoader) {
-               return new ByteArrayInputStream(data);
+       public InputStream getState(ClassLoader userCodeClassLoader) throws 
Exception {
+               ensureNotClosed();
+
+               ByteArrayInputStream stream = new ByteArrayInputStream(data);
+               registerCloseable(stream); // NOTE(review): ByteArrayInputStream.close() has no effect, so this does not abort reads already in progress
+
+               return stream;
        }
 
        @Override
@@ -59,6 +66,18 @@ public final class ByteStreamStateHandle implements 
StreamStateHandle {
 
        @Override
        public <T extends Serializable> StateHandle<T> toSerializableHandle() {
-               return new SerializedStateHandle<T>(data);
+               SerializedStateHandle<T> serializableHandle = new 
SerializedStateHandle<T>(data);
+
+               // if this handle was already closed, close the new handle too, so the closed status carries over
+               if (isClosed()) {
+                       try {
+                               serializableHandle.close();
+                       } catch (IOException e) {
+                               // should not happen on a fresh handle, but forward anyway
+                               throw new RuntimeException(e);
+                       }
+               }
+
+               return serializableHandle;
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/SerializedStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/SerializedStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/SerializedStateHandle.java
index 9aef733..4420470 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/SerializedStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/SerializedStateHandle.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.memory;
 
+import org.apache.flink.runtime.state.AbstractCloseableHandle;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.util.InstantiationUtil;
 
@@ -29,7 +30,7 @@ import java.io.Serializable;
  *
  * @param <T> The type of state represented by this state handle.
  */
-public class SerializedStateHandle<T extends Serializable> implements 
StateHandle<T> {
+public class SerializedStateHandle<T extends Serializable> extends 
AbstractCloseableHandle implements StateHandle<T> {
        
        private static final long serialVersionUID = 4145685722538475769L;
 
@@ -61,6 +62,7 @@ public class SerializedStateHandle<T extends Serializable> 
implements StateHandl
                        throw new NullPointerException();
                }
 
+               ensureNotClosed(); // checked once up front; the in-memory deserialization below is not registered for closing
                return serializedData == null ? null : 
InstantiationUtil.<T>deserializeObject(serializedData, classLoader);
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index afd2405..73bf204 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -98,12 +98,14 @@ public class CheckpointMessagesTest {
                }
 
                @Override
-               public void discardState() throws Exception {
-               }
+               public void discardState() throws Exception {} // intentionally a no-op
 
                @Override
                public long getStateSize() {
                        return 0;
                }
+
+               @Override
+               public void close() throws IOException {} // intentionally a no-op
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
new file mode 100644
index 0000000..ad3339a
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.junit.Test;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+/**
+ * Tests that {@link AbstractCloseableHandle} closes registered {@link Closeable}s exactly once,
+ * whether they are registered before or after the handle itself is closed.
+ */
+public class AbstractCloseableHandleTest {
+
+       @Test
+       public void testRegisterThenClose() throws Exception {
+               Closeable closeable = mock(Closeable.class);
+
+               AbstractCloseableHandle handle = new CloseableHandle();
+               assertFalse(handle.isClosed());
+
+               // no immediate closing
+               handle.registerCloseable(closeable);
+               verify(closeable, times(0)).close();
+               assertFalse(handle.isClosed());
+
+               // close forwarded once
+               handle.close();
+               verify(closeable, times(1)).close();
+               assertTrue(handle.isClosed());
+
+               // no repeated closing
+               handle.close();
+               verify(closeable, times(1)).close();
+               assertTrue(handle.isClosed());
+       }
+
+       @Test
+       public void testCloseThenRegister() throws Exception {
+               Closeable closeable = mock(Closeable.class);
+
+               AbstractCloseableHandle handle = new CloseableHandle();
+               assertFalse(handle.isClosed());
+
+               // close the handle before setting the closeable
+               handle.close();
+               assertTrue(handle.isClosed());
+
+               // immediate closing
+               try {
+                       handle.registerCloseable(closeable);
+                       fail("this should throw an exception");
+               } catch (IOException e) {
+                       // expected
+                       assertTrue(e.getMessage().contains("closed"));
+               }
+
+               // should still have called "close" on the Closeable
+               verify(closeable, times(1)).close();
+               assertTrue(handle.isClosed());
+
+               // no repeated closing
+               handle.close();
+               verify(closeable, times(1)).close();
+               assertTrue(handle.isClosed());
+       }
+
+       // 
------------------------------------------------------------------------
+
+       private static final class CloseableHandle extends 
AbstractCloseableHandle {
+               private static final long serialVersionUID = 1L;
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
index e166ed5..7505bfc 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
@@ -33,6 +33,7 @@ import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
@@ -586,5 +587,8 @@ public class ZooKeeperStateHandleStoreITCase extends 
TestLogger {
                public int getNumberOfDiscardCalls() {
                        return numberOfDiscardCalls;
                }
+
+               @Override
+               public void close() throws IOException {} // intentionally a no-op
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
index b268c7a..23cfc3a 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
@@ -29,6 +29,7 @@ import 
org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskState;
+import org.apache.flink.util.ExceptionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -218,6 +219,34 @@ public abstract class GenericWriteAheadSink<IN> extends 
AbstractStreamOperator<I
                        return stateSize;
                }
 
+               /** Closes all pending state handles; every handle is tried, then the first failure is rethrown via ExceptionUtils.rethrowIOException with later ones attached as suppressed. */
+               @Override
+               public void close() throws IOException {
+                       Throwable exception = null;
+
+                       for (Tuple2<Long, StateHandle<DataInputView>> pair : 
pendingHandles.values()) {
+                               StateHandle<DataInputView> handle = pair.f1;
+                               if (handle != null) {
+                                       try {
+                                               handle.close();
+                                       }
+                                       catch (Throwable t) {
+                                               // remember the first failure, but keep closing the remaining handles
+                                               if (exception == null) {
+                                                       exception = t;
+                                               } else if (exception != t) {
+                                                       exception.addSuppressed(t);
+                                               }
+                                       }
+                               }
+                       }
+
+                       if (exception != null) {
+                               ExceptionUtils.rethrowIOException(exception);
+                       }
+               }
+
+               @Override
                public String toString() {
                        return this.pendingHandles.toString();
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 6ad94b4..940f699 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -45,13 +45,16 @@ import 
org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.operators.Triggerable;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.Closeable;
 import java.io.Serializable;
-import java.util.Collections;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ScheduledFuture;
@@ -147,7 +150,7 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
        private volatile AsynchronousException asyncException;
 
-       /** The currently active background materialization threads */
-       private final Set<Thread> asyncCheckpointThreads = 
Collections.synchronizedSet(new HashSet<Thread>());
+       /** Closeables (such as state that is being restored) that are closed on cancellation and cleanup; guarded by synchronizing on the set itself */
+       private final Set<Closeable> cancelables = new HashSet<Closeable>();
        
        /** Flag to mark the task "in operation", in which case check
         * needs to be initialized to true, so that early cancel() before 
invoke() behaves correctly */
@@ -244,7 +247,7 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                        LOG.debug("Invoking {}", getName());
 
                        // first order of business is to give operators back 
their state
-                       restoreState();
+                       restoreState(lazyRestoreState);
                        lazyRestoreState = null; // GC friendliness
                        
                        // we need to make sure that any triggers scheduled in 
open() cannot be
@@ -303,10 +306,7 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                        
-                       // stop all asynchronous checkpoint threads
+                       // close all registered cancelables (e.g. state that is being restored)
                        try {
-                               for (Thread checkpointThread : 
asyncCheckpointThreads) {
-                                       checkpointThread.interrupt();
-                               }
-                               asyncCheckpointThreads.clear();
+                               closeAllClosables();
                        }
                        catch (Throwable t) {
                                // catch and log the exception to not replace 
the original exception
@@ -352,6 +352,7 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                isRunning = false;
                canceled = true;
                cancelTask();
+               closeAllClosables(); // also aborts in-progress operations registered as cancelables, e.g. a state restore
        }
 
        public final boolean isRunning() {
@@ -447,8 +448,28 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                        timerService.shutdownService();
                }
 
-               for (Thread checkpointThread : asyncCheckpointThreads) {
-                       checkpointThread.interrupt();
+               closeAllClosables();
+       }
+
+       private void closeAllClosables() {
+               // Closes and unregisters all cancelables. First copy (and clear) the set under the lock, so the
+               // lock is not held while closing; a failing close is logged and does not stop the others.
+               List<Closeable> localCancelables = null;
+               synchronized (cancelables) {
+                       if (cancelables.size() > 0) {
+                               localCancelables = new ArrayList<>(cancelables);
+                               cancelables.clear();
+                       }
+               }
+
+               if (localCancelables != null) {
+                       for (Closeable cancelable : localCancelables) {
+                               try {
+                                       cancelable.close();
+                               } catch (Throwable t) {
+                                       LOG.error("Error on canceling 
operation", t);
+                               }
+                       }
                }
        }
 
@@ -502,16 +523,17 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                lazyRestoreState = initialState;
        }
 
-       private void restoreState() throws Exception {
-               if (lazyRestoreState != null) {
+       private void restoreState(StreamTaskStateList restoredState) throws 
Exception {
+               if (restoredState != null) {
                        LOG.info("Restoring checkpointed state to task {}", 
getName());
                        
+                       synchronized (cancelables) {
+                               cancelables.add(restoredState); // NOTE(review): if cancel() already drained 'cancelables', this state is never closed; consider re-checking 'canceled' after registering
+                       }
+
                        try {
                                final StreamOperator<?>[] allOperators = 
operatorChain.getAllOperators();
-                               final StreamTaskState[] states = 
lazyRestoreState.getState(userClassLoader);
-                               
-                               // be GC friendly
-                               lazyRestoreState = null;
+                               final StreamTaskState[] states = 
restoredState.getState(userClassLoader);
                                
                                for (int i = 0; i < states.length; i++) {
                                        StreamTaskState state = states[i];
@@ -529,6 +551,11 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                        catch (Exception e) {
                                throw new Exception("Could not restore 
checkpointed state to operators and functions", e);
                        }
+                       finally {
+                               synchronized (cancelables) {
+                                       cancelables.remove(restoredState);
+                               }
+                       }
                }
        }
 
@@ -603,54 +630,13 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                                } else {
                                        // start a Thread that does the 
asynchronous materialization and
                                        // then sends the checkpoint acknowledge
-
                                        String threadName = "Materialize 
checkpoint state " + checkpointId + " - " + getName();
-                                       Thread checkpointThread = new 
Thread(threadName) {
-                                               @Override
-                                               public void run() {
-                                                       try {
-                                                               for 
(StreamTaskState state : states) {
-                                                                       if 
(state != null) {
-                                                                               
if (state.getFunctionState() instanceof AsynchronousStateHandle) {
-                                                                               
        AsynchronousStateHandle<Serializable> asyncState = 
(AsynchronousStateHandle<Serializable>) state.getFunctionState();
-                                                                               
        state.setFunctionState(asyncState.materialize());
-                                                                               
}
-                                                                               
if (state.getOperatorState() instanceof AsynchronousStateHandle) {
-                                                                               
        AsynchronousStateHandle<?> asyncState = (AsynchronousStateHandle<?>) 
state.getOperatorState();
-                                                                               
        state.setOperatorState(asyncState.materialize());
-                                                                               
}
-                                                                               
if (state.getKvStates() != null) {
-                                                                               
        Set<String> keys = state.getKvStates().keySet();
-                                                                               
        HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> kvStates = 
state.getKvStates();
-                                                                               
        for (String key: keys) {
-                                                                               
                if (kvStates.get(key) instanceof AsynchronousKvStateSnapshot) {
-                                                                               
                        AsynchronousKvStateSnapshot<?, ?, ?, ?, ?> asyncHandle 
= (AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) kvStates.get(key);
-                                                                               
                        kvStates.put(key, asyncHandle.materialize());
-                                                                               
                }
-                                                                               
        }
-                                                                               
}
-
-                                                                       }
-                                                               }
-                                                               
StreamTaskStateList allStates = new StreamTaskStateList(states);
-                                                               
StreamTask.this.lastCheckpointSize = allStates.getStateSize();
-                                                               
getEnvironment().acknowledgeCheckpoint(checkpointId, allStates);
-                                                               
LOG.debug("Finished asynchronous checkpoints for checkpoint {} on task {}", 
checkpointId, getName());
-                                                       }
-                                                       catch (Exception e) {
-                                                               if 
(isRunning()) {
-                                                                       
LOG.error("Caught exception while materializing asynchronous checkpoints.", e);
-                                                               }
-                                                               if 
(asyncException == null) {
-                                                                       
asyncException = new AsynchronousException(e);
-                                                               }
-                                                       }
-                                                       
asyncCheckpointThreads.remove(this);
-                                               }
-                                       };
+                                       AsyncCheckpointThread checkpointThread 
= new AsyncCheckpointThread(
+                                                       threadName, this, 
cancelables, states, checkpointId);
 
-                                       
asyncCheckpointThreads.add(checkpointThread);
-                                       checkpointThread.setDaemon(true);
+                                       synchronized (cancelables) {
+                                               
cancelables.add(checkpointThread);
+                                       }
                                        checkpointThread.start();
                                }
                                return true;
@@ -784,7 +770,7 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                        }
                };
        }
-       
+
        // 
------------------------------------------------------------------------
 
        /**
@@ -820,4 +806,80 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
                        }
                }
        }
+
+       // 
------------------------------------------------------------------------
+       
+       private static class AsyncCheckpointThread extends Thread implements 
Closeable {
+
+               private final StreamTask<?, ?> owner;
+
+               private final Set<Closeable> cancelables;
+
+               private final StreamTaskState[] states;
+
+               private final long checkpointId;
+
+               AsyncCheckpointThread(String name, StreamTask<?, ?> owner, 
Set<Closeable> cancelables,
+                               StreamTaskState[] states, long checkpointId) {
+                       super(name);
+                       setDaemon(true);
+
+                       this.owner = owner;
+                       this.cancelables = cancelables;
+                       this.states = states;
+                       this.checkpointId = checkpointId;
+               }
+
+               @Override
+               public void run() {
+                       try {
+                               for (StreamTaskState state : states) {
+                                       if (state != null) {
+                                               if (state.getFunctionState() 
instanceof AsynchronousStateHandle) {
+                                                       
AsynchronousStateHandle<Serializable> asyncState = 
(AsynchronousStateHandle<Serializable>) state.getFunctionState();
+                                                       
state.setFunctionState(asyncState.materialize());
+                                               }
+                                               if (state.getOperatorState() 
instanceof AsynchronousStateHandle) {
+                                                       
AsynchronousStateHandle<?> asyncState = (AsynchronousStateHandle<?>) 
state.getOperatorState();
+                                                       
state.setOperatorState(asyncState.materialize());
+                                               }
+                                               if (state.getKvStates() != 
null) {
+                                                       Set<String> keys = 
state.getKvStates().keySet();
+                                                       HashMap<String, 
KvStateSnapshot<?, ?, ?, ?, ?>> kvStates = state.getKvStates();
+                                                       for (String key: keys) {
+                                                               if 
(kvStates.get(key) instanceof AsynchronousKvStateSnapshot) {
+                                                                       
AsynchronousKvStateSnapshot<?, ?, ?, ?, ?> asyncHandle = 
(AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) kvStates.get(key);
+                                                                       
kvStates.put(key, asyncHandle.materialize());
+                                                               }
+                                                       }
+                                               }
+
+                                       }
+                               }
+                               StreamTaskStateList allStates = new 
StreamTaskStateList(states);
+                               owner.lastCheckpointSize = 
allStates.getStateSize();
+                               
owner.getEnvironment().acknowledgeCheckpoint(checkpointId, allStates);
+
+                               LOG.debug("Finished asynchronous checkpoints 
for checkpoint {} on task {}", checkpointId, getName());
+                       }
+                       catch (Exception e) {
+                               if (owner.isRunning()) {
+                                       LOG.error("Caught exception while 
materializing asynchronous checkpoints.", e);
+                               }
+                               if (owner.asyncException == null) {
+                                       owner.asyncException = new 
AsynchronousException(e);
+                               }
+                       }
+                       finally {
+                               synchronized (cancelables) {
+                                       cancelables.remove(this);
+                               }
+                       }
+               }
+
+               @Override
+               public void close() {
+                       interrupt();
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java
index c9e29d3..925dd8c 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskState.java
@@ -21,7 +21,10 @@ package org.apache.flink.streaming.runtime.tasks;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.runtime.state.KvStateSnapshot;
+import org.apache.flink.util.ExceptionUtils;
 
+import java.io.Closeable;
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.ConcurrentModificationException;
 import java.util.HashMap;
@@ -37,7 +40,7 @@ import java.util.Iterator;
  * </ul>
  */
 @Internal
-public class StreamTaskState implements Serializable {
+public class StreamTaskState implements Serializable, Closeable {
 
        private static final long serialVersionUID = 1L;
        
@@ -122,4 +125,61 @@ public class StreamTaskState implements Serializable {
                        }
                }
        }
+
+       /**
+        * Closes the operator state, the function state, and all key/value state snapshots of
+        * this task state, in that order. The fields are cleared before anything is closed, so a
+        * repeated call finds nothing left to close.
+        *
+        * <p>Every handle is closed even if closing an earlier one fails. The first failure is
+        * rethrown via {@link ExceptionUtils#rethrowIOException(Throwable)} after all handles
+        * have been attempted.
+        *
+        * @throws IOException if closing any of the contained handles failed
+        */
+       @Override
+       public void close() throws IOException {
+               final StateHandle<?>[] handles = { this.operatorState, this.functionState };
+               final HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshots = this.kvStates;
+
+               this.operatorState = null;
+               this.functionState = null;
+               this.kvStates = null;
+
+               Throwable failure = null;
+
+               for (StateHandle<?> handle : handles) {
+                       if (handle == null) {
+                               continue;
+                       }
+                       try {
+                               handle.close();
+                       } catch (Throwable t) {
+                               if (failure == null) {
+                                       failure = t;
+                               }
+                       }
+               }
+
+               if (snapshots != null) {
+                       // drain the map; a ConcurrentModificationException restarts the iteration,
+                       // and the loop only ends once the map is empty
+                       while (!snapshots.isEmpty()) {
+                               try {
+                                       final Iterator<KvStateSnapshot<?, ?, ?, ?, ?>> it = snapshots.values().iterator();
+                                       while (it.hasNext()) {
+                                               // next() stays outside the inner try so that a concurrent
+                                               // modification restarts the outer loop instead of being recorded
+                                               final KvStateSnapshot<?, ?, ?, ?, ?> snapshot = it.next();
+                                               try {
+                                                       snapshot.close();
+                                               } catch (Throwable t) {
+                                                       if (failure == null) {
+                                                               failure = t;
+                                                       }
+                                               }
+                                               it.remove();
+                                       }
+                               }
+                               catch (ConcurrentModificationException ignored) {
+                                       // retry with a fresh iterator
+                               }
+                       }
+               }
+
+               if (failure != null) {
+                       ExceptionUtils.rethrowIOException(failure);
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java
index d8a5b2f..ae85d86 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTaskStateList.java
@@ -21,7 +21,9 @@ package org.apache.flink.streaming.runtime.tasks;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.state.KvStateSnapshot;
 import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.ExceptionUtils;
 
+import java.io.IOException;
 import java.util.HashMap;
 
 /**
@@ -47,7 +49,7 @@ public class StreamTaskStateList implements 
StateHandle<StreamTaskState[]> {
                }
                return true;
        }
-       
+
        @Override
        public StreamTaskState[] getState(ClassLoader userCodeClassLoader) {
                return states;
@@ -95,4 +97,27 @@ public class StreamTaskStateList implements 
StateHandle<StreamTaskState[]> {
                // State size as sum of all state sizes
                return sumStateSize;
        }
+
+       /**
+        * Closes every non-null task state in this list. All states are closed even if some of
+        * them fail; the first failure is rethrown afterwards via
+        * {@link ExceptionUtils#rethrowIOException(Throwable)}.
+        *
+        * @throws IOException if closing any of the task states failed
+        */
+       @Override
+       public void close() throws IOException {
+               final StreamTaskState[] all = states;
+               if (all == null) {
+                       return;
+               }
+
+               Throwable firstFailure = null;
+               for (int i = 0; i < all.length; i++) {
+                       final StreamTaskState taskState = all[i];
+                       if (taskState == null) {
+                               continue;
+                       }
+                       try {
+                               taskState.close();
+                       } catch (Throwable t) {
+                               if (firstFailure == null) {
+                                       firstFailure = t;
+                               }
+                       }
+               }
+
+               if (firstFailure != null) {
+                       ExceptionUtils.rethrowIOException(firstFailure);
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
new file mode 100644
index 0000000..5237c62
--- /dev/null
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.tasks;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.blob.BlobKey;
+import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
+import org.apache.flink.runtime.execution.ExecutionState;
+import 
org.apache.flink.runtime.execution.librarycache.FallbackLibraryCacheManager;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.filecache.FileCache;
+import org.apache.flink.runtime.instance.ActorGateway;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.NetworkEnvironment;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.memory.MemoryManager;
+import 
org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.runtime.taskmanager.Task;
+import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
+import org.apache.flink.runtime.util.EnvironmentInformation;
+import org.apache.flink.runtime.util.SerializableObject;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.util.SerializedValue;
+
+import org.junit.Test;
+
+import scala.concurrent.duration.FiniteDuration;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.net.URL;
+import java.util.Collections;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+/**
+ * This test checks that task restores that get stuck in the presence of interrupts
+ * are handled properly.
+ *
+ * <p>In practice, reading from HDFS is interrupt sensitive: The HDFS code frequently deadlocks
+ * or livelocks if it is interrupted.
+ */
+public class InterruptSensitiveRestoreTest {
+
+       /** Triggered once the restoring task thread has entered the blocking state handle. */
+       private static final OneShotLatch IN_RESTORE_LATCH = new OneShotLatch();
+
+       @Test
+       public void testRestoreWithInterrupt() throws Exception {
+
+               Configuration taskConfig = new Configuration();
+               StreamConfig cfg = new StreamConfig(taskConfig);
+               cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+               cfg.setStreamOperator(new StreamSource<>(new TestSource()));
+
+               StateHandle<Serializable> lockingHandle = new InterruptLockingStateHandle();
+               StreamTaskState opState = new StreamTaskState();
+               opState.setFunctionState(lockingHandle);
+               StreamTaskStateList taskState = new StreamTaskStateList(new StreamTaskState[] { opState });
+
+               TaskDeploymentDescriptor tdd = createTaskDeploymentDescriptor(taskConfig, taskState);
+               Task task = createTask(tdd);
+
+               // start the task and wait until it is in "restore"
+               task.startTaskThread();
+               IN_RESTORE_LATCH.await();
+
+               // trigger cancellation; the restore only finishes if the state handle gets closed
+               task.cancelExecution();
+
+               task.getExecutingThread().join(30000);
+
+               if (task.getExecutionState() == ExecutionState.CANCELING) {
+                       fail("Task is stuck and not canceling");
+               }
+
+               assertEquals(ExecutionState.CANCELED, task.getExecutionState());
+               assertNull(task.getFailureCause());
+       }
+
+       // ------------------------------------------------------------------------
+       //  Utilities
+       // ------------------------------------------------------------------------
+
+       /**
+        * Creates a deployment descriptor for a {@link SourceStreamTask} with the given task
+        * configuration and initial state.
+        */
+       private static TaskDeploymentDescriptor createTaskDeploymentDescriptor(
+                       Configuration taskConfig,
+                       StateHandle<?> state) throws IOException {
+
+               return new TaskDeploymentDescriptor(
+                               new JobID(),
+                               "test job name",
+                               new JobVertexID(),
+                               new ExecutionAttemptID(),
+                               new SerializedValue<>(new ExecutionConfig()),
+                               "test task name",
+                               0, 1, 0,
+                               new Configuration(),
+                               taskConfig,
+                               SourceStreamTask.class.getName(),
+                               Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
+                               Collections.<InputGateDeploymentDescriptor>emptyList(),
+                               Collections.<BlobKey>emptyList(),
+                               Collections.<URL>emptyList(),
+                               0,
+                               new SerializedValue<StateHandle<?>>(state));
+       }
+
+       /** Creates a task for the given descriptor, backed by mocked runtime services. */
+       private static Task createTask(TaskDeploymentDescriptor tdd) throws IOException {
+               return new Task(
+                               tdd,
+                               mock(MemoryManager.class),
+                               mock(IOManager.class),
+                               mock(NetworkEnvironment.class),
+                               mock(BroadcastVariableManager.class),
+                               mock(ActorGateway.class),
+                               mock(ActorGateway.class),
+                               new FiniteDuration(10, TimeUnit.SECONDS),
+                               new FallbackLibraryCacheManager(),
+                               new FileCache(new Configuration()),
+                               new TaskManagerRuntimeInfo(
+                                               "localhost", new Configuration(), EnvironmentInformation.getTemporaryFileDirectory()),
+                               new UnregisteredTaskMetricsGroup());
+       }
+
+       // ------------------------------------------------------------------------
+
+       /**
+        * A state handle whose {@link #getState(ClassLoader)} blocks, swallowing interrupts,
+        * until the handle is closed.
+        */
+       @SuppressWarnings("serial")
+       private static class InterruptLockingStateHandle implements StateHandle<Serializable> {
+
+               private transient volatile boolean closed;
+
+               @Override
+               public Serializable getState(ClassLoader userCodeClassLoader) {
+                       IN_RESTORE_LATCH.trigger();
+
+                       // this mimics what happens in the HDFS client code:
+                       // interrupts are swallowed and the thread keeps waiting, so only
+                       // closing the handle releases it (the loop also absorbs spurious wakeups)
+                       synchronized (this) {
+                               while (!closed) {
+                                       try {
+                                               wait();
+                                       } catch (InterruptedException ignored) {
+                                               // deliberately swallowed to mimic the interrupt-insensitive client
+                                       }
+                               }
+                       }
+
+                       return new SerializableObject();
+               }
+
+               @Override
+               public void discardState() throws Exception {}
+
+               @Override
+               public long getStateSize() throws Exception {
+                       return 0;
+               }
+
+               @Override
+               public void close() throws IOException {
+                       synchronized (this) {
+                               closed = true;
+                               // wake a thread blocked in getState(), whether or not it was interrupted before
+                               notifyAll();
+                       }
+               }
+       }
+
+       // ------------------------------------------------------------------------
+
+       /** Source whose methods must never run, because the task is canceled during restore. */
+       private static class TestSource implements SourceFunction<Object>, Checkpointed<Serializable> {
+               private static final long serialVersionUID = 1L;
+
+               @Override
+               public void run(SourceContext<Object> ctx) throws Exception {
+                       fail("should never be called");
+               }
+
+               @Override
+               public void cancel() {}
+
+               @Override
+               public Serializable snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+                       fail("should never be called");
+                       return null;
+               }
+
+               @Override
+               public void restoreState(Serializable state) throws Exception {
+                       fail("should never be called");
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e9f660d1/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskAsyncCheckpointTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskAsyncCheckpointTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskAsyncCheckpointTest.java
index b74903a..cfaeaad 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskAsyncCheckpointTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskAsyncCheckpointTest.java
@@ -37,6 +37,7 @@ import 
org.powermock.core.classloader.annotations.PowerMockIgnore;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
+import java.io.IOException;
 import java.lang.reflect.Field;
 
 import static org.junit.Assert.assertEquals;
@@ -196,6 +197,9 @@ public class StreamTaskAsyncCheckpointTest {
                public long getStateSize() {
                        return 0;
                }
+
+               @Override
+               public void close() throws IOException {}
        }
 
        private static class TestStateHandle implements StateHandle<String> {
@@ -214,13 +218,15 @@ public class StreamTaskAsyncCheckpointTest {
                }
 
                @Override
-               public void discardState() throws Exception {
-               }
+               public void discardState() throws Exception {}
 
                @Override
                public long getStateSize() {
                        return 0;
                }
+
+               @Override
+               public void close() throws IOException {}
        }
        
        public static class DummyMapFunction<T> implements MapFunction<T, T> {

Reply via email to