[FLINK-6014] [checkpoint] Allow the registration of state objects in checkpoints


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/218bed8b
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/218bed8b
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/218bed8b

Branch: refs/heads/master
Commit: 218bed8b8e49b0e4c61c61f696a8f010eafea1b7
Parents: db31ca3
Author: xiaogang.sxg <[email protected]>
Authored: Mon Mar 13 19:23:47 2017 +0800
Committer: Stefan Richter <[email protected]>
Committed: Sat Apr 22 15:25:56 2017 +0200

----------------------------------------------------------------------
 .../checkpoint/CheckpointCoordinator.java       | 136 ++++++++----
 .../runtime/checkpoint/CompletedCheckpoint.java |  73 +++++--
 .../checkpoint/CompletedCheckpointStore.java    |   9 +-
 .../runtime/checkpoint/PendingCheckpoint.java   |  16 +-
 .../StandaloneCompletedCheckpointStore.java     |  18 +-
 .../flink/runtime/checkpoint/SubtaskState.java  |  49 +++--
 .../flink/runtime/checkpoint/TaskState.java     |  30 ++-
 .../ZooKeeperCompletedCheckpointStore.java      |  67 ++++--
 .../runtime/state/CompositeStateHandle.java     |  60 ++++++
 .../flink/runtime/state/SharedStateHandle.java  |  39 ++++
 .../runtime/state/SharedStateRegistry.java      | 178 ++++++++++++++++
 .../flink/runtime/jobmanager/JobManager.scala   |  27 +--
 .../CheckpointCoordinatorFailureTest.java       |  26 ++-
 .../checkpoint/CheckpointCoordinatorTest.java   | 211 ++++++++++++++-----
 .../checkpoint/CheckpointStateRestoreTest.java  |   4 +-
 .../CompletedCheckpointStoreTest.java           |  78 +++++--
 .../checkpoint/CompletedCheckpointTest.java     |  24 ++-
 ...ExecutionGraphCheckpointCoordinatorTest.java |   9 +-
 .../checkpoint/PendingCheckpointTest.java       |  22 +-
 .../StandaloneCompletedCheckpointStoreTest.java |  27 ++-
 ...ZooKeeperCompletedCheckpointStoreITCase.java |  40 ++--
 .../ZooKeeperCompletedCheckpointStoreTest.java  |  11 +-
 .../jobmanager/JobManagerHARecoveryTest.java    |  65 +-----
 .../runtime/state/SharedStateRegistryTest.java  | 136 ++++++++++++
 .../RecoverableCompletedCheckpointStore.java    | 109 ++++++++++
 25 files changed, 1176 insertions(+), 288 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 7087540..5309dd4 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.runtime.checkpoint.savepoint.SavepointLoader;
 import org.apache.flink.runtime.checkpoint.savepoint.SavepointStore;
 import org.apache.flink.runtime.concurrent.ApplyFunction;
 import org.apache.flink.runtime.concurrent.Future;
@@ -36,10 +37,11 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
-import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.TaskStateHandles;
 
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
+import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -108,6 +110,9 @@ public class CheckpointCoordinator {
        /** Completed checkpoints. Implementations can be blocking. Make sure 
calls to methods
         * accessing this don't block the job manager actor and run 
asynchronously. */
        private final CompletedCheckpointStore completedCheckpointStore;
+       
+       /** Registry for shared states */
+       private final SharedStateRegistry sharedStateRegistry;
 
        /** Default directory for persistent checkpoints; <code>null</code> if 
none configured.
         * THIS WILL BE REPLACED BY PROPER STATE-BACKEND METADATA WRITING */
@@ -218,6 +223,7 @@ public class CheckpointCoordinator {
                this.completedCheckpointStore = 
checkNotNull(completedCheckpointStore);
                this.checkpointDirectory = checkpointDirectory;
                this.executor = checkNotNull(executor);
+               this.sharedStateRegistry = new SharedStateRegistry();
 
                this.recentPendingCheckpoints = new 
ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS);
 
@@ -282,7 +288,7 @@ public class CheckpointCoordinator {
                                }
                                pendingCheckpoints.clear();
 
-                               completedCheckpointStore.shutdown(jobStatus);
+                               completedCheckpointStore.shutdown(jobStatus, 
sharedStateRegistry);
                                checkpointIdCounter.shutdown(jobStatus);
                        }
                }
@@ -615,7 +621,7 @@ public class CheckpointCoordinator {
                        throw new IllegalArgumentException("Received 
DeclineCheckpoint message for job " +
                                message.getJob() + " while this coordinator 
handles job " + job);
                }
-
+               
                final long checkpointId = message.getCheckpointId();
                final String reason = (message.getReason() != null ? 
message.getReason().getMessage() : "");
 
@@ -695,7 +701,7 @@ public class CheckpointCoordinator {
                }
 
                final long checkpointId = message.getCheckpointId();
-
+               
                synchronized (lock) {
                        // we need to check inside the lock for being shutdown 
as well, otherwise we
                        // get races and invalid error log messages
@@ -778,49 +784,55 @@ public class CheckpointCoordinator {
         */
        private void completePendingCheckpoint(PendingCheckpoint 
pendingCheckpoint) throws CheckpointException {
                final long checkpointId = pendingCheckpoint.getCheckpointId();
-               CompletedCheckpoint completedCheckpoint = null;
+               final CompletedCheckpoint completedCheckpoint;
 
                try {
-                       // externalize the checkpoint if required
-                       if 
(pendingCheckpoint.getProps().externalizeCheckpoint()) {
-                               completedCheckpoint = 
pendingCheckpoint.finalizeCheckpointExternalized();
-                       } else {
-                               completedCheckpoint = 
pendingCheckpoint.finalizeCheckpointNonExternalized();
-                       }
-
-                       
completedCheckpointStore.addCheckpoint(completedCheckpoint);
-
-                       rememberRecentCheckpointId(checkpointId);
-                       dropSubsumedCheckpoints(checkpointId);
-               }
-               catch (Exception exception) {
-                       // abort the current pending checkpoint if it has not 
been discarded yet
-                       if (!pendingCheckpoint.isDiscarded()) {
-                               pendingCheckpoint.abortError(exception);
+                       try {
+                               // externalize the checkpoint if required
+                               if 
(pendingCheckpoint.getProps().externalizeCheckpoint()) {
+                                       completedCheckpoint = 
pendingCheckpoint.finalizeCheckpointExternalized();
+                               } else {
+                                       completedCheckpoint = 
pendingCheckpoint.finalizeCheckpointNonExternalized();
+                               }
+                       } catch (Exception e1) {
+                               // abort the current pending checkpoint if we 
fail to finalize the pending checkpoint.
+                               if (!pendingCheckpoint.isDiscarded()) {
+                                       pendingCheckpoint.abortError(e1);
+                               }
+       
+                               throw new CheckpointException("Could not 
finalize the pending checkpoint " + checkpointId + '.', e1);
                        }
-
-                       if (completedCheckpoint != null) {
+       
+                       // the pending checkpoint must be discarded after the 
finalization
+                       
Preconditions.checkState(pendingCheckpoint.isDiscarded() && completedCheckpoint 
!= null);
+       
+                       try {
+                               
completedCheckpointStore.addCheckpoint(completedCheckpoint, 
sharedStateRegistry);
+                       } catch (Exception exception) {
                                // we failed to store the completed checkpoint. 
Let's clean up
-                               final CompletedCheckpoint cc = 
completedCheckpoint;
-
                                executor.execute(new Runnable() {
                                        @Override
                                        public void run() {
                                                try {
-                                                       cc.discard();
+                                                       
completedCheckpoint.discardOnFail();
                                                } catch (Throwable t) {
-                                                       LOG.warn("Could not 
properly discard completed checkpoint {}.", cc.getCheckpointID(), t);
+                                                       LOG.warn("Could not 
properly discard completed checkpoint {}.", 
completedCheckpoint.getCheckpointID(), t);
                                                }
                                        }
                                });
+                               
+                               throw new CheckpointException("Could not 
complete the pending checkpoint " + checkpointId + '.', exception);
                        }
-
-                       throw new CheckpointException("Could not complete the 
pending checkpoint " + checkpointId + '.', exception);
                } finally {
                        pendingCheckpoints.remove(checkpointId);
 
                        triggerQueuedRequests();
                }
+               
+               rememberRecentCheckpointId(checkpointId);
+               
+               // drop those pending checkpoints that are prior to the 
completed one
+               dropSubsumedCheckpoints(checkpointId);
 
                // record the time when this was completed, to calculate
                // the 'min delay between checkpoints'
@@ -941,7 +953,7 @@ public class CheckpointCoordinator {
                        }
 
                        // Recover the checkpoints
-                       completedCheckpointStore.recover();
+                       completedCheckpointStore.recover(sharedStateRegistry);
 
                        // restore from the latest checkpoint
                        CompletedCheckpoint latest = 
completedCheckpointStore.getLatestCheckpoint();
@@ -978,6 +990,44 @@ public class CheckpointCoordinator {
                }
        }
 
+       /**
+        * Restore the state with the given savepoint.
+        * 
+        * @param savepointPath    Location of the savepoint
+        * @param allowNonRestored True if allowing checkpoint state that 
cannot be 
+        *                         mapped to any job vertex in tasks.
+        * @param tasks            Map of job vertices to restore. State for 
these 
+        *                         vertices is restored via 
+        *                         {@link 
Execution#setInitialState(TaskStateHandles)}.
+        * @param userClassLoader  The class loader to resolve serialized 
classes in 
+        *                         legacy savepoint versions. 
+        */
+       public boolean restoreSavepoint(
+                       String savepointPath, 
+                       boolean allowNonRestored,
+                       Map<JobVertexID, ExecutionJobVertex> tasks,
+                       ClassLoader userClassLoader) throws Exception {
+               
+               Preconditions.checkNotNull(savepointPath, "The savepoint path 
cannot be null.");
+               
+               LOG.info("Starting job from savepoint {} ({})", 
+                               savepointPath, (allowNonRestored ? "allowing 
non restored state" : ""));
+
+               // Load the savepoint as a checkpoint into the system
+               CompletedCheckpoint savepoint = 
SavepointLoader.loadAndValidateSavepoint(
+                               job, tasks, savepointPath, userClassLoader, 
allowNonRestored);
+
+               completedCheckpointStore.addCheckpoint(savepoint, 
sharedStateRegistry);
+               
+               // Reset the checkpoint ID counter
+               long nextCheckpointId = savepoint.getCheckpointID() + 1;
+               checkpointIdCounter.setCount(nextCheckpointId);
+               
+               LOG.info("Reset the checkpoint ID to {}.", nextCheckpointId);
+               
+               return restoreLatestCheckpointedState(tasks, true, 
allowNonRestored);
+       }
+
        // 
--------------------------------------------------------------------------------------------
        //  Accessors
        // 
--------------------------------------------------------------------------------------------
@@ -1007,6 +1057,10 @@ public class CheckpointCoordinator {
        public CompletedCheckpointStore getCheckpointStore() {
                return completedCheckpointStore;
        }
+       
+       public SharedStateRegistry getSharedStateRegistry() {
+               return sharedStateRegistry;
+       }
 
        public CheckpointIDCounter getCheckpointIdCounter() {
                return checkpointIdCounter;
@@ -1095,24 +1149,30 @@ public class CheckpointCoordinator {
         * @param jobId identifying the job to which the state object belongs
         * @param executionAttemptID identifying the task to which the state 
object belongs
         * @param checkpointId of the state object
-        * @param stateObject to discard asynchronously
+        * @param subtaskState to discard asynchronously
         */
        private void discardState(
                        final JobID jobId,
                        final ExecutionAttemptID executionAttemptID,
                        final long checkpointId,
-                       final StateObject stateObject) {
+                       final SubtaskState subtaskState) {
 
-               if (stateObject != null) {
+               if (subtaskState != null) {
                        executor.execute(new Runnable() {
                                @Override
                                public void run() {
                                        try {
-                                               stateObject.discardState();
-                                       } catch (Throwable throwable) {
-                                       LOG.warn("Could not properly discard 
state object of checkpoint {} " +
-                                               "belonging to task {} of job 
{}.", checkpointId, executionAttemptID, jobId,
-                                               throwable);
+                                               
subtaskState.discardSharedStatesOnFail();
+                                       } catch (Throwable t1) {
+                                               LOG.warn("Could not properly 
discard shared states of checkpoint {} " +
+                                                       "belonging to task {} 
of job {}.", checkpointId, executionAttemptID, jobId, t1);
+                                       }
+
+                                       try {
+                                               subtaskState.discardState();
+                                       } catch (Throwable t2) {
+                                               LOG.warn("Could not properly 
discard state object of checkpoint {} " +
+                                                       "belonging to task {} 
of job {}.", checkpointId, executionAttemptID, jobId, t2);
                                        }
                                }
                        });

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index 17ce4d5..58e91e1 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -23,10 +23,12 @@ import org.apache.flink.api.common.JobID;
 import 
org.apache.flink.runtime.checkpoint.CompletedCheckpointStats.DiscardCallback;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.ExceptionUtils;
 
+import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -184,22 +186,30 @@ public class CompletedCheckpoint implements Serializable {
                return props;
        }
 
-       public boolean subsume() throws Exception {
+       public void discardOnFail() throws Exception {
+               discard(null, true);
+       }
+
+       public boolean discardOnSubsume(SharedStateRegistry 
sharedStateRegistry) throws Exception {
+               Preconditions.checkNotNull(sharedStateRegistry, "The registry 
cannot be null.");
+
                if (props.discardOnSubsumed()) {
-                       discard();
+                       discard(sharedStateRegistry, false);
                        return true;
                }
 
                return false;
        }
 
-       public boolean discard(JobStatus jobStatus) throws Exception {
+       public boolean discardOnShutdown(JobStatus jobStatus, 
SharedStateRegistry sharedStateRegistry) throws Exception {
+               Preconditions.checkNotNull(sharedStateRegistry, "The registry 
cannot be null.");
+
                if (jobStatus == JobStatus.FINISHED && 
props.discardOnJobFinished() ||
                                jobStatus == JobStatus.CANCELED && 
props.discardOnJobCancelled() ||
                                jobStatus == JobStatus.FAILED && 
props.discardOnJobFailed() ||
                                jobStatus == JobStatus.SUSPENDED && 
props.discardOnJobSuspended()) {
 
-                       discard();
+                       discard(sharedStateRegistry, false);
                        return true;
                } else {
                        if (externalPointer != null) {
@@ -211,7 +221,10 @@ public class CompletedCheckpoint implements Serializable {
                }
        }
 
-       void discard() throws Exception {
+       private void discard(SharedStateRegistry sharedStateRegistry, boolean 
failed) throws Exception {
+               Preconditions.checkState(failed || (sharedStateRegistry != 
null),
+                       "The registry must not be null if the complete 
checkpoint does not fail.");
+
                try {
                        // collect exceptions and continue cleanup
                        Exception exception = null;
@@ -220,25 +233,31 @@ public class CompletedCheckpoint implements Serializable {
                        if (externalizedMetadata != null) {
                                try {
                                        externalizedMetadata.discardState();
-                               }
-                               catch (Exception e) {
+                               } catch (Exception e) {
                                        exception = e;
                                }
                        }
 
-                       // drop the actual state
+                       // In the cases where the completed checkpoint fails, 
the shared
+                       // states have not been registered to the registry. 
It's the state
+                       // handles' responsibility to discard their shared 
states.
+                       if (!failed) {
+                               unregisterSharedStates(sharedStateRegistry);
+                       } else {
+                               discardSharedStatesOnFail();
+                       }
+
+                       // discard private state objects
                        try {
                                
StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
-                       }
-                       catch (Exception e) {
+                       } catch (Exception e) {
                                exception = ExceptionUtils.firstOrSuppressed(e, 
exception);
                        }
 
                        if (exception != null) {
                                throw exception;
                        }
-               }
-               finally {
+               } finally {
                        taskStates.clear();
 
                        // to be null-pointer safe, copy reference to stack
@@ -290,6 +309,36 @@ public class CompletedCheckpoint implements Serializable {
                this.discardCallback = discardCallback;
        }
 
+       /**
+        * Register all shared states in the given registry. This method is 
called
+        * when the completed checkpoint has been successfully added into the 
store.
+        *
+        * @param sharedStateRegistry The registry where shared states are 
registered
+        */
+       public void registerSharedStates(SharedStateRegistry 
sharedStateRegistry) {
+               sharedStateRegistry.registerAll(taskStates.values());
+       }
+
+       /**
+        * Unregister all shared states from the given registry. This method 
is
+        * called when the completed checkpoint is subsumed or the job 
terminates.
+        *
+        * @param sharedStateRegistry The registry where shared states are 
registered
+        */
+       private void unregisterSharedStates(SharedStateRegistry 
sharedStateRegistry) {
+               sharedStateRegistry.unregisterAll(taskStates.values());
+       }
+
+       /**
+        * Discard all shared states created in the checkpoint. This method is 
called
+        * when the completed checkpoint fails to be added into the store.
+        */
+       private void discardSharedStatesOnFail() throws Exception {
+               for (TaskState taskState : taskStates.values()) {
+                       taskState.discardSharedStatesOnFail();
+               }
+       }
+
        // 
--------------------------------------------------------------------------------------------
 
        @Override

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
index 9c2b199..0ade25c 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 
 import java.util.List;
 
@@ -33,16 +34,16 @@ public interface CompletedCheckpointStore {
         * <p>After a call to this method, {@link #getLatestCheckpoint()} 
returns the latest
         * available checkpoint.
         */
-       void recover() throws Exception;
+       void recover(SharedStateRegistry sharedStateRegistry) throws Exception;
 
        /**
         * Adds a {@link CompletedCheckpoint} instance to the list of completed 
checkpoints.
         *
         * <p>Only a bounded number of checkpoints is kept. When exceeding the 
maximum number of
         * retained checkpoints, the oldest one will be discarded via {@link
-        * CompletedCheckpoint#discard()}.
+        * CompletedCheckpoint#discardOnSubsume(SharedStateRegistry)}.
         */
-       void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception;
+       void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry 
sharedStateRegistry) throws Exception;
 
        /**
         * Returns the latest {@link CompletedCheckpoint} instance or 
<code>null</code> if none was
@@ -58,7 +59,7 @@ public interface CompletedCheckpointStore {
         *
         * @param jobStatus Job state on shut down
         */
-       void shutdown(JobStatus jobStatus) throws Exception;
+       void shutdown(JobStatus jobStatus, SharedStateRegistry 
sharedStateRegistry) throws Exception;
 
        /**
         * Returns all {@link CompletedCheckpoint} instances.

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index e1182ae..6805dea 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -497,6 +497,7 @@ public class PendingCheckpoint {
        }
 
        private void dispose(boolean releaseState) {
+
                synchronized (lock) {
                        try {
                                numAcknowledgedTasks = -1;
@@ -504,11 +505,22 @@ public class PendingCheckpoint {
                                        executor.execute(new Runnable() {
                                                @Override
                                                public void run() {
+
+                                                       // discard the shared 
states that are created in the checkpoint
+                                                       for (TaskState 
taskState : taskStates.values()) {
+                                                               try {
+                                                                       
taskState.discardSharedStatesOnFail();
+                                                               } catch 
(Throwable t) {
+                                                                       
LOG.warn("Could not properly dispose unreferenced shared states.");
+                                                               }
+                                                       }
+
+                                                       // discard the private 
states
                                                        try {
                                                                
StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
                                                        } catch (Throwable t) {
-                                                               LOG.warn("Could 
not properly dispose the pending checkpoint {} of job {}.", 
-                                                                               
checkpointId, jobId, t);
+                                                               LOG.warn("Could 
not properly dispose the private states in the pending checkpoint {} of job 
{}.",
+                                                                       
checkpointId, jobId, t);
                                                        } finally {
                                                                
taskStates.clear();
                                                        }

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
index 6eb5242..9f833c3 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobmanager.HighAvailabilityMode;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -56,16 +57,21 @@ public class StandaloneCompletedCheckpointStore implements 
CompletedCheckpointSt
        }
 
        @Override
-       public void recover() throws Exception {
+       public void recover(SharedStateRegistry sharedStateRegistry) throws 
Exception {
                // Nothing to do
        }
 
        @Override
-       public void addCheckpoint(CompletedCheckpoint checkpoint) throws 
Exception {
-               checkpoints.add(checkpoint);
+       public void addCheckpoint(CompletedCheckpoint checkpoint, 
SharedStateRegistry sharedStateRegistry) throws Exception {
+               
+               checkpoints.addLast(checkpoint);
+
+               checkpoint.registerSharedStates(sharedStateRegistry);
+
                if (checkpoints.size() > maxNumberOfCheckpointsToRetain) {
                        try {
-                               checkpoints.remove().subsume();
+                               CompletedCheckpoint checkpointToSubsume = 
checkpoints.removeFirst();
+                               
checkpointToSubsume.discardOnSubsume(sharedStateRegistry);
                        } catch (Exception e) {
                                LOG.warn("Fail to subsume the old checkpoint.", 
e);
                        }
@@ -93,12 +99,12 @@ public class StandaloneCompletedCheckpointStore implements 
CompletedCheckpointSt
        }
 
        @Override
-       public void shutdown(JobStatus jobStatus) throws Exception {
+       public void shutdown(JobStatus jobStatus, SharedStateRegistry 
sharedStateRegistry) throws Exception {
                try {
                        LOG.info("Shutting down");
 
                        for (CompletedCheckpoint checkpoint : checkpoints) {
-                               checkpoint.discard(jobStatus);
+                               checkpoint.discardOnShutdown(jobStatus, 
sharedStateRegistry);
                        }
                } finally {
                        checkpoints.clear();

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 97b08fc..e968643 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -19,11 +19,15 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CompositeStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.Arrays;
 
@@ -33,7 +37,9 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  * Container for the chained state of one parallel subtask of an 
operator/task. This is part of the
  * {@link TaskState}.
  */
-public class SubtaskState implements StateObject {
+public class SubtaskState implements CompositeStateHandle {
+
+       private static final Logger LOG = 
LoggerFactory.getLogger(SubtaskState.class);
 
        private static final long serialVersionUID = -2394696997971923995L;
 
@@ -130,19 +136,38 @@ public class SubtaskState implements StateObject {
        }
 
        @Override
-       public long getStateSize() {
-               return stateSize;
+       public void discardState() {
+               try {
+                       StateUtil.bestEffortDiscardAllStateObjects(
+                               Arrays.asList(
+                                       legacyOperatorState,
+                                       managedOperatorState,
+                                       rawOperatorState,
+                                       managedKeyedState,
+                                       rawKeyedState));
+               } catch (Exception e) {
+                       LOG.warn("Error while discarding operator states.", e);
+               }
        }
 
        @Override
-       public void discardState() throws Exception {
-               StateUtil.bestEffortDiscardAllStateObjects(
-                               Arrays.asList(
-                                               legacyOperatorState,
-                                               managedOperatorState,
-                                               rawOperatorState,
-                                               managedKeyedState,
-                                               rawKeyedState));
+       public void registerSharedStates(SharedStateRegistry 
sharedStateRegistry) {
+               // No shared states
+       }
+
+       @Override
+       public void unregisterSharedStates(SharedStateRegistry 
sharedStateRegistry) {
+               // No shared states
+       }
+
+       @Override
+       public void discardSharedStatesOnFail() {
+               // No shared states
+       }
+
+       @Override
+       public long getStateSize() {
+               return stateSize;
        }
 
        // 
--------------------------------------------------------------------------------------------
@@ -206,7 +231,7 @@ public class SubtaskState implements StateObject {
                                ", operatorStateFromBackend=" + 
managedOperatorState +
                                ", operatorStateFromStream=" + rawOperatorState 
+
                                ", keyedStateFromBackend=" + managedKeyedState +
-                               ", keyedStateHandleFromStream=" + rawKeyedState 
+
+                               ", keyedStateFromStream=" + rawKeyedState +
                                ", stateSize=" + stateSize +
                                '}';
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
index 76f1c51..19fe962 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
@@ -19,8 +19,8 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.state.StateObject;
-import org.apache.flink.runtime.state.StateUtil;
+import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.util.Preconditions;
 
 import java.util.Collection;
@@ -35,7 +35,7 @@ import java.util.Objects;
  *
  * This class basically groups all non-partitioned state and key-group state 
belonging to the same job vertex together.
  */
-public class TaskState implements StateObject {
+public class TaskState implements CompositeStateHandle {
 
        private static final long serialVersionUID = -4845578005863201810L;
 
@@ -124,9 +124,31 @@ public class TaskState implements StateObject {
 
        @Override
        public void discardState() throws Exception {
-               
StateUtil.bestEffortDiscardAllStateObjects(subtaskStates.values());
+               for (SubtaskState subtaskState : subtaskStates.values()) {
+                       subtaskState.discardState();
+               }
+       }
+
+       @Override
+       public void registerSharedStates(SharedStateRegistry 
sharedStateRegistry) {
+               for (SubtaskState subtaskState : subtaskStates.values()) {
+                       subtaskState.registerSharedStates(sharedStateRegistry);
+               }
+       }
+
+       @Override
+       public void unregisterSharedStates(SharedStateRegistry 
sharedStateRegistry) {
+               for (SubtaskState subtaskState : subtaskStates.values()) {
+                       
subtaskState.unregisterSharedStates(sharedStateRegistry);
+               }
        }
 
+       @Override
+       public void discardSharedStatesOnFail() {
+               for (SubtaskState subtaskState : subtaskStates.values()) {
+                       subtaskState.discardSharedStatesOnFail();
+               }
+       }
 
        @Override
        public long getStateSize() {

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
index af7bcc4..07546ea 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
@@ -27,6 +27,7 @@ import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobmanager.HighAvailabilityMode;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore;
 import org.apache.flink.util.FlinkException;
@@ -123,7 +124,7 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
                this.checkpointsInZooKeeper = new 
ZooKeeperStateHandleStore<>(this.client, stateStorage, executor);
 
                this.checkpointStateHandles = new 
ArrayDeque<>(maxNumberOfCheckpointsToRetain + 1);
-
+               
                LOG.info("Initialized in '{}'.", checkpointsPath);
        }
 
@@ -140,7 +141,7 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
         * that the history of checkpoints is consistent.
         */
        @Override
-       public void recover() throws Exception {
+       public void recover(SharedStateRegistry sharedStateRegistry) throws 
Exception {
                LOG.info("Recovering checkpoints from ZooKeeper.");
 
                // Clear local handles in order to prevent duplicates on
@@ -164,8 +165,24 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
 
                LOG.info("Found {} checkpoints in ZooKeeper.", 
numberOfInitialCheckpoints);
 
-               for (Tuple2<RetrievableStateHandle<CompletedCheckpoint>, 
String> checkpoint : initialCheckpoints) {
-                       checkpointStateHandles.add(checkpoint);
+               for (Tuple2<RetrievableStateHandle<CompletedCheckpoint>, 
String> checkpointStateHandle : initialCheckpoints) {
+
+                       CompletedCheckpoint completedCheckpoint = null;
+
+                       try {
+                               completedCheckpoint = 
retrieveCompletedCheckpoint(checkpointStateHandle);
+                       } catch (Exception e) {
+                               LOG.warn("Could not retrieve checkpoint. 
Removing it from the completed " +
+                                       "checkpoint store.", e);
+
+                               // remove the checkpoint with broken state 
handle
+                               removeBrokenStateHandle(checkpointStateHandle);
+                       }
+
+                       if (completedCheckpoint != null) {
+                               
completedCheckpoint.registerSharedStates(sharedStateRegistry);
+                               
checkpointStateHandles.add(checkpointStateHandle);
+                       }
                }
        }
 
@@ -175,21 +192,24 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
         * @param checkpoint Completed checkpoint to add.
         */
        @Override
-       public void addCheckpoint(CompletedCheckpoint checkpoint) throws 
Exception {
+       public void addCheckpoint(final CompletedCheckpoint checkpoint, final 
SharedStateRegistry sharedStateRegistry) throws Exception {
                checkNotNull(checkpoint, "Checkpoint");
+               
+               final String path = 
checkpointIdToPath(checkpoint.getCheckpointID());
+               final RetrievableStateHandle<CompletedCheckpoint> stateHandle;
 
                // First add the new one. If it fails, we don't want to loose 
existing data.
-               String path = checkpointIdToPath(checkpoint.getCheckpointID());
-
-               final RetrievableStateHandle<CompletedCheckpoint> stateHandle =
-                               checkpointsInZooKeeper.add(path, checkpoint);
+               stateHandle = checkpointsInZooKeeper.add(path, checkpoint);
 
                checkpointStateHandles.addLast(new Tuple2<>(stateHandle, path));
 
+               // Register all shared states in the checkpoint
+               checkpoint.registerSharedStates(sharedStateRegistry);
+
                // Everything worked, let's remove a previous checkpoint if 
necessary.
                while (checkpointStateHandles.size() > 
maxNumberOfCheckpointsToRetain) {
                        try {
-                               
removeSubsumed(checkpointStateHandles.removeFirst());
+                               
removeSubsumed(checkpointStateHandles.removeFirst(), sharedStateRegistry);
                        } catch (Exception e) {
                                LOG.warn("Failed to subsume the old 
checkpoint", e);
                        }
@@ -261,13 +281,13 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
        }
 
        @Override
-       public void shutdown(JobStatus jobStatus) throws Exception {
+       public void shutdown(JobStatus jobStatus, SharedStateRegistry 
sharedStateRegistry) throws Exception {
                if (jobStatus.isGloballyTerminalState()) {
                        LOG.info("Shutting down");
 
                        for 
(Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> checkpoint : 
checkpointStateHandles) {
                                try {
-                                       removeShutdown(checkpoint, jobStatus);
+                                       removeShutdown(checkpoint, jobStatus, 
sharedStateRegistry);
                                } catch (Exception e) {
                                        LOG.error("Failed to discard 
checkpoint.", e);
                                }
@@ -289,11 +309,19 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
 
        // 
------------------------------------------------------------------------
 
-       private void removeSubsumed(final 
Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> stateHandleAndPath) 
throws Exception {
+       private void removeSubsumed(
+               final Tuple2<RetrievableStateHandle<CompletedCheckpoint>, 
String> stateHandleAndPath,
+               final SharedStateRegistry sharedStateRegistry) throws Exception 
{
+               
                Callable<Void> action = new Callable<Void>() {
                        @Override
                        public Void call() throws Exception {
-                               stateHandleAndPath.f0.retrieveState().subsume();
+                               CompletedCheckpoint completedCheckpoint = 
retrieveCompletedCheckpoint(stateHandleAndPath);
+                               
+                               if (completedCheckpoint != null) {
+                                       
completedCheckpoint.discardOnSubsume(sharedStateRegistry);
+                               }
+
                                return null;
                        }
                };
@@ -303,13 +331,18 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
 
        private void removeShutdown(
                        final 
Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> stateHandleAndPath,
-                       final JobStatus jobStatus) throws Exception {
+                       final JobStatus jobStatus,
+                       final SharedStateRegistry sharedStateRegistry) throws 
Exception {
 
                Callable<Void> action = new Callable<Void>() {
                        @Override
                        public Void call() throws Exception {
-                               CompletedCheckpoint checkpoint = 
stateHandleAndPath.f0.retrieveState();
-                               checkpoint.discard(jobStatus);
+                               CompletedCheckpoint completedCheckpoint = 
retrieveCompletedCheckpoint(stateHandleAndPath);
+                               
+                               if (completedCheckpoint != null) {
+                                       
completedCheckpoint.discardOnShutdown(jobStatus, sharedStateRegistry);
+                               }
+
                                return null;
                        }
                };

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
new file mode 100644
index 0000000..2ea5bc9
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * Base of all snapshots that are taken by {@link StateBackend}s and some other
+ * components in tasks.
+ *
+ * <p>Each snapshot is composed of a collection of {@link StateObject}s, some 
of 
+ * which may be referenced by other checkpoints. The shared states will be 
+ * registered at the given {@link SharedStateRegistry} when the handle is
+ * received by the {@link 
org.apache.flink.runtime.checkpoint.CheckpointCoordinator}
+ * and will be discarded when the checkpoint is discarded.
+ * 
+ * <p>The {@link SharedStateRegistry} is responsible for the discarding of the
+ * shared states. The composite state handle should only delete those private
+ * states in the {@link StateObject#discardState()} method.
+ */
+public interface CompositeStateHandle extends StateObject {
+
+       /**
+        * Register both created and referenced shared states in the given
+        * {@link SharedStateRegistry}. This method is called when the 
checkpoint
+        * successfully completes or is recovered from failures.
+        *
+        * @param stateRegistry The registry where shared states are registered.
+        */
+       void registerSharedStates(SharedStateRegistry stateRegistry);
+
+       /**
+        * Unregister both created and referenced shared states in the given
+        * {@link SharedStateRegistry}. This method is called when the 
checkpoint is
+        * subsumed or the job is shut down.
+        *
+        * @param stateRegistry The registry where shared states are registered.
+        */
+       void unregisterSharedStates(SharedStateRegistry stateRegistry);
+
+       /**
+        * Discard all shared states created in this checkpoint. This method is
+        * called when the checkpoint fails to complete.
+        */
+       void discardSharedStatesOnFail() throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
new file mode 100644
index 0000000..f856052
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * A handle to those states that are referenced by different checkpoints.
+ *
+ * <p> Each shared state handle is identified by a unique key. Two shared 
states
+ * are considered equal if their keys are identical.
+ *
+ * <p> All shared states are registered at the {@link SharedStateRegistry} once
+ * they are received by the {@link 
org.apache.flink.runtime.checkpoint.CheckpointCoordinator}
+ * and will be unregistered when the checkpoints are discarded. A shared state
+ * will be discarded once it is not referenced by any checkpoint. A shared 
state
+ * should not be referenced any more if it has been discarded.
+ */
+public interface SharedStateHandle extends StateObject {
+
+       /**
+        * Return the identifier of the shared state.
+        */
+       String getKey();
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
new file mode 100644
index 0000000..b5048d0
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.util.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A {@code SharedStateRegistry} will be deployed in the 
+ * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} to 
+ * maintain the reference count of {@link SharedStateHandle}s which are shared
+ * among different checkpoints.
+ */
+public class SharedStateRegistry implements Serializable {
+
+       private static Logger LOG = 
LoggerFactory.getLogger(SharedStateRegistry.class);
+
+       private static final long serialVersionUID = -8357254413007773970L;
+
+       /** All registered state objects */
+       private final Map<String, SharedStateEntry> registeredStates = new 
HashMap<>();
+
+       /**
+        * Register the state in the registry
+        *
+        * @param state The state to register
+        * @param isNew True if the shared state is newly created
+        */
+       public void register(SharedStateHandle state, boolean isNew) {
+               if (state == null) {
+                       return;
+               }
+
+               synchronized (registeredStates) {
+                       SharedStateEntry entry = 
registeredStates.get(state.getKey());
+
+                       if (isNew) {
+                               Preconditions.checkState(entry == null,
+                                       "The state cannot be created more than 
once.");
+
+                               registeredStates.put(state.getKey(), new 
SharedStateEntry(state));
+                       } else {
+                               Preconditions.checkState(entry != null,
+                                       "The state cannot be referenced if it 
has not been created yet.");
+
+                               entry.increaseReferenceCount();
+                       }
+               }
+       }
+
+       /**
+        * Unregister the state in the registry
+        *
+        * @param state The state to unregister
+        */
+       public void unregister(SharedStateHandle state) {
+               if (state == null) {
+                       return;
+               }
+
+               synchronized (registeredStates) {
+                       SharedStateEntry entry = 
registeredStates.get(state.getKey());
+
+                       if (entry == null) {
+                               throw new IllegalStateException("Cannot 
unregister an unexisted state.");
+                       }
+
+                       entry.decreaseReferenceCount();
+
+                       // Remove the state from the registry when it's not 
referenced any more.
+                       if (entry.getReferenceCount() == 0) {
+                               registeredStates.remove(state.getKey());
+
+                               try {
+                                       entry.getState().discardState();
+                               } catch (Exception e) {
+                                       LOG.warn("Cannot properly discard the 
state " + entry.getState() + ".", e);
+                               }
+                       }
+               }
+       }
+
+       /**
+        * Register the shared states of the given composite state handles in the registry.
+        *
+        * @param stateHandles The composite state handles whose shared states are registered.
+        */
+       public void registerAll(Collection<? extends CompositeStateHandle> 
stateHandles) {
+               if (stateHandles == null) {
+                       return;
+               }
+
+               synchronized (registeredStates) {
+                       for (CompositeStateHandle stateHandle : stateHandles) {
+                               stateHandle.registerSharedStates(this);
+                       }
+               }
+       }
+
+
+
+       /**
+        * Unregister all the shared states referenced by the given composite state handles.
+        *
+        * @param stateHandles The composite state handles whose shared states are unregistered.
+        */
+       public void unregisterAll(Collection<? extends CompositeStateHandle> 
stateHandles) {
+               if (stateHandles == null) {
+                       return;
+               }
+
+               synchronized (registeredStates) {
+                       for (CompositeStateHandle stateHandle : stateHandles) {
+                               stateHandle.unregisterSharedStates(this);
+                       }
+               }
+       }
+
+       private static class SharedStateEntry {
+               /** The shared object */
+               private final SharedStateHandle state;
+
+               /** The reference count of the object */
+               private int referenceCount;
+
+               SharedStateEntry(SharedStateHandle value) {
+                       this.state = value;
+                       this.referenceCount = 1;
+               }
+
+               SharedStateHandle getState() {
+                       return state;
+               }
+
+               int getReferenceCount() {
+                       return referenceCount;
+               }
+
+               void increaseReferenceCount() {
+                       ++referenceCount;
+               }
+
+               void decreaseReferenceCount() {
+                       --referenceCount;
+               }
+       }
+
+
+       @VisibleForTesting
+       public int getReferenceCount(SharedStateHandle state) {
+               SharedStateEntry entry = registeredStates.get(state.getKey());
+
+               return entry == null ? 0 : entry.getReferenceCount();
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
 
b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
index f2ecde5..40e2c2a 100644
--- 
a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
+++ 
b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
@@ -1365,27 +1365,12 @@ class JobManager(
                 val savepointPath = savepointSettings.getRestorePath()
                 val allowNonRestored = 
savepointSettings.allowNonRestoredState()
 
-                log.info(s"Starting job from savepoint '$savepointPath'" +
-                  (if (allowNonRestored) " (allowing non restored state)" else 
"") + ".")
-
-                  // load the savepoint as a checkpoint into the system
-                  val savepoint: CompletedCheckpoint = 
SavepointLoader.loadAndValidateSavepoint(
-                    jobId,
-                    executionGraph.getAllVertices,
-                    savepointPath,
-                    executionGraph.getUserClassLoader,
-                    allowNonRestored)
-
-                executionGraph.getCheckpointCoordinator.getCheckpointStore
-                  .addCheckpoint(savepoint)
-
-                // Reset the checkpoint ID counter
-                val nextCheckpointId: Long = savepoint.getCheckpointID + 1
-                log.info(s"Reset the checkpoint ID to $nextCheckpointId")
-                executionGraph.getCheckpointCoordinator.getCheckpointIdCounter
-                  .setCount(nextCheckpointId)
-
-                executionGraph.restoreLatestCheckpointedState(true, 
allowNonRestored)
+                executionGraph.getCheckpointCoordinator.restoreSavepoint(
+                  savepointPath, 
+                  allowNonRestored,
+                  executionGraph.getAllVertices,
+                  executionGraph.getUserClassLoader
+                )
               } catch {
                 case e: Exception =>
                   jobInfo.notifyClients(

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 340e2a7..632f2c0 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -25,6 +25,7 @@ import 
org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.util.TestLogger;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -38,7 +39,9 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
 @RunWith(PowerMockRunner.class)
@@ -83,12 +86,13 @@ public class CheckpointCoordinatorFailureTest extends 
TestLogger {
                assertFalse(pendingCheckpoint.isDiscarded());
 
                final long checkpointId = 
coord.getPendingCheckpoints().keySet().iterator().next();
-
-               AcknowledgeCheckpoint acknowledgeMessage = new 
AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId);
-
-               CompletedCheckpoint completedCheckpoint = 
mock(CompletedCheckpoint.class);
-               
PowerMockito.whenNew(CompletedCheckpoint.class).withAnyArguments().thenReturn(completedCheckpoint);
-
+               
+               SubtaskState subtaskState = mock(SubtaskState.class);
+               
PowerMockito.when(subtaskState.getLegacyOperatorState()).thenReturn(null);
+               
PowerMockito.when(subtaskState.getManagedOperatorState()).thenReturn(null);
+               
+               AcknowledgeCheckpoint acknowledgeMessage = new 
AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new 
CheckpointMetrics(), subtaskState);
+               
                try {
                        coord.receiveAcknowledgeMessage(acknowledgeMessage);
                        fail("Expected a checkpoint exception because the 
completed checkpoint store could not " +
@@ -100,18 +104,20 @@ public class CheckpointCoordinatorFailureTest extends 
TestLogger {
                // make sure that the pending checkpoint has been discarded 
after we could not complete it
                assertTrue(pendingCheckpoint.isDiscarded());
 
-               verify(completedCheckpoint).discard();
+               // make sure that the subtask state has been discarded after we 
could not complete it.
+               verify(subtaskState, times(1)).discardSharedStatesOnFail();
+               verify(subtaskState).discardState();
        }
 
        private static final class FailingCompletedCheckpointStore implements 
CompletedCheckpointStore {
 
                @Override
-               public void recover() throws Exception {
+               public void recover(SharedStateRegistry sharedStateRegistry) 
throws Exception {
                        throw new UnsupportedOperationException("Not 
implemented.");
                }
 
                @Override
-               public void addCheckpoint(CompletedCheckpoint checkpoint) 
throws Exception {
+               public void addCheckpoint(CompletedCheckpoint checkpoint, 
SharedStateRegistry sharedStateRegistry) throws Exception {
                        throw new Exception("The failing completed checkpoint 
store failed again... :-(");
                }
 
@@ -121,7 +127,7 @@ public class CheckpointCoordinatorFailureTest extends 
TestLogger {
                }
 
                @Override
-               public void shutdown(JobStatus jobStatus) throws Exception {
+               public void shutdown(JobStatus jobStatus, SharedStateRegistry 
sharedStateRegistry) throws Exception {
                        throw new UnsupportedOperationException("Not 
implemented.");
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 117c70d..fabf3fc 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -43,11 +43,13 @@ import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
@@ -86,13 +88,17 @@ import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.withSettings;
 
 /**
  * Tests for the checkpoint coordinator.
@@ -545,19 +551,24 @@ public class CheckpointCoordinatorTest {
                        }
 
                        // acknowledge from one of the tasks
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointId));
+                       SubtaskState subtaskState2 = mock(SubtaskState.class);
+                       AcknowledgeCheckpoint acknowledgeCheckpoint1 = new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), 
subtaskState2);
+                       coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
                        assertEquals(1, 
checkpoint.getNumberOfAcknowledgedTasks());
                        assertEquals(1, 
checkpoint.getNumberOfNonAcknowledgedTasks());
                        assertFalse(checkpoint.isDiscarded());
                        assertFalse(checkpoint.isFullyAcknowledged());
+                       verify(subtaskState2, 
never()).registerSharedStates(any(SharedStateRegistry.class));
 
                        // acknowledge the same task again (should not matter)
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointId));
+                       coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
                        assertFalse(checkpoint.isDiscarded());
                        assertFalse(checkpoint.isFullyAcknowledged());
+                       verify(subtaskState2, 
never()).registerSharedStates(any(SharedStateRegistry.class));
 
                        // acknowledge the other task.
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, attemptID1, checkpointId));
+                       SubtaskState subtaskState1 = mock(SubtaskState.class);
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), 
subtaskState1));
 
                        // the checkpoint is internally converted to a 
successful checkpoint and the
                        // pending checkpoint object is disposed
@@ -570,6 +581,12 @@ public class CheckpointCoordinatorTest {
                        // the canceler should be removed now
                        assertEquals(0, coord.getNumScheduledTasks());
 
+                       // validate that the subtask states have registered 
their shared states.
+                       {
+                               verify(subtaskState1, 
times(1)).registerSharedStates(any(SharedStateRegistry.class));
+                               verify(subtaskState2, 
times(1)).registerSharedStates(any(SharedStateRegistry.class));
+                       }
+
                        // validate that the relevant tasks got a confirmation 
message
                        {
                                verify(vertex1.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), 
any(CheckpointOptions.class));
@@ -580,7 +597,7 @@ public class CheckpointCoordinatorTest {
                        assertEquals(jid, success.getJobId());
                        assertEquals(timestamp, success.getTimestamp());
                        assertEquals(checkpoint.getCheckpointId(), 
success.getCheckpointID());
-                       assertTrue(success.getTaskStates().isEmpty());
+                       assertEquals(2, success.getTaskStates().size());
 
                        // ---------------
                        // trigger another checkpoint and see that this one 
replaces the other checkpoint
@@ -602,6 +619,12 @@ public class CheckpointCoordinatorTest {
                        assertEquals(checkpointIdNew, 
successNew.getCheckpointID());
                        assertTrue(successNew.getTaskStates().isEmpty());
 
+                       // validate that the subtask states in the old checkpoint 
have unregistered their shared states
+                       {
+                               verify(subtaskState1, 
times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+                               verify(subtaskState2, 
times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+                       }
+
                        // validate that the relevant tasks got a confirmation 
message
                        {
                                verify(vertex1.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), 
any(CheckpointOptions.class));
@@ -678,8 +701,6 @@ public class CheckpointCoordinatorTest {
                        verify(triggerVertex1.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), 
any(CheckpointOptions.class));
                        verify(triggerVertex2.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), 
any(CheckpointOptions.class));
 
-                       CheckpointMetaData checkpointMetaData1 = new 
CheckpointMetaData(checkpointId1, 0L);
-
                        // acknowledge one of the three tasks
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1));
 
@@ -699,8 +720,6 @@ public class CheckpointCoordinatorTest {
                        }
                        long checkpointId2 = pending2.getCheckpointId();
 
-                       CheckpointMetaData checkpointMetaData2 = new 
CheckpointMetaData(checkpointId2, 0L);
-
                        // trigger messages should have been sent
                        verify(triggerVertex1.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), 
any(CheckpointOptions.class));
                        verify(triggerVertex2.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), 
any(CheckpointOptions.class));
@@ -812,10 +831,9 @@ public class CheckpointCoordinatorTest {
                        verify(triggerVertex1.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), 
any(CheckpointOptions.class));
                        verify(triggerVertex2.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), 
any(CheckpointOptions.class));
 
-                       CheckpointMetaData checkpointMetaData1 = new 
CheckpointMetaData(checkpointId1, 0L);
-
                        // acknowledge one of the three tasks
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1));
+                       SubtaskState subtaskState1_2 = mock(SubtaskState.class);
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new 
CheckpointMetrics(), subtaskState1_2));
 
                        // start the second checkpoint
                        // trigger the first checkpoint. this should succeed
@@ -839,12 +857,18 @@ public class CheckpointCoordinatorTest {
 
                        // we acknowledge one more task from the first 
checkpoint and the second
                        // checkpoint completely. The second checkpoint should 
then subsume the first checkpoint
-                       CheckpointMetaData checkpointMetaData2= new 
CheckpointMetaData(checkpointId2, 0L);
 
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2));
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2));
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1));
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2));
+                       SubtaskState subtaskState2_3 = mock(SubtaskState.class);
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new 
CheckpointMetrics(), subtaskState2_3));
+
+                       SubtaskState subtaskState2_1 = mock(SubtaskState.class);
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new 
CheckpointMetrics(), subtaskState2_1));
+
+                       SubtaskState subtaskState1_1 = mock(SubtaskState.class);
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new 
CheckpointMetrics(), subtaskState1_1));
+
+                       SubtaskState subtaskState2_2 = mock(SubtaskState.class);
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new 
CheckpointMetrics(), subtaskState2_2));
 
                        // now, the second checkpoint should be confirmed, and 
the first discarded
                        // actually both pending checkpoints are discarded, and 
the second has been transformed
@@ -855,21 +879,47 @@ public class CheckpointCoordinatorTest {
                        assertEquals(0, coord.getNumberOfPendingCheckpoints());
                        assertEquals(1, 
coord.getNumberOfRetainedSuccessfulCheckpoints());
 
+                       // validate that all received subtask states in the 
first checkpoint have been discarded
+                       verify(subtaskState1_1, 
times(1)).discardSharedStatesOnFail();
+                       verify(subtaskState1_2, 
times(1)).discardSharedStatesOnFail();
+                       verify(subtaskState1_1, times(1)).discardState();
+                       verify(subtaskState1_2, times(1)).discardState();
+
+                       // validate that all subtask states in the second 
checkpoint are not discarded
+                       verify(subtaskState2_1, 
never()).unregisterSharedStates(any(SharedStateRegistry.class));
+                       verify(subtaskState2_2, 
never()).unregisterSharedStates(any(SharedStateRegistry.class));
+                       verify(subtaskState2_3, 
never()).unregisterSharedStates(any(SharedStateRegistry.class));
+                       verify(subtaskState2_1, never()).discardState();
+                       verify(subtaskState2_2, never()).discardState();
+                       verify(subtaskState2_3, never()).discardState();
+
                        // validate the committed checkpoints
                        List<CompletedCheckpoint> scs = 
coord.getSuccessfulCheckpoints();
                        CompletedCheckpoint success = scs.get(0);
                        assertEquals(checkpointId2, success.getCheckpointID());
                        assertEquals(timestamp2, success.getTimestamp());
                        assertEquals(jid, success.getJobId());
-                       assertTrue(success.getTaskStates().isEmpty());
+                       assertEquals(3, success.getTaskStates().size());
 
                        // the first confirm message should be out
                        verify(commitVertex.getCurrentExecutionAttempt(), 
times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2));
 
                        // send the last remaining ack for the first 
checkpoint. This should not do anything
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1));
+                       SubtaskState subtaskState1_3 = mock(SubtaskState.class);
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new 
CheckpointMetrics(), subtaskState1_3));
+                       verify(subtaskState1_3, 
times(1)).discardSharedStatesOnFail();
+                       verify(subtaskState1_3, times(1)).discardState();
 
                        coord.shutdown(JobStatus.FINISHED);
+
+                       // validate that the states in the second checkpoint 
have been discarded
+                       verify(subtaskState2_1, 
times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+                       verify(subtaskState2_2, 
times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+                       verify(subtaskState2_3, 
times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+                       verify(subtaskState2_1, times(1)).discardState();
+                       verify(subtaskState2_2, times(1)).discardState();
+                       verify(subtaskState2_3, times(1)).discardState();
+
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -924,7 +974,8 @@ public class CheckpointCoordinatorTest {
                        PendingCheckpoint checkpoint = 
coord.getPendingCheckpoints().values().iterator().next();
                        assertFalse(checkpoint.isDiscarded());
 
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId()));
+                       SubtaskState subtaskState = mock(SubtaskState.class);
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new 
CheckpointMetrics(), subtaskState));
 
                        // wait until the checkpoint must have expired.
                        // we check every 250 msecs conservatively for 5 seconds
@@ -941,6 +992,10 @@ public class CheckpointCoordinatorTest {
                        assertEquals(0, coord.getNumberOfPendingCheckpoints());
                        assertEquals(0, 
coord.getNumberOfRetainedSuccessfulCheckpoints());
 
+                       // validate that the received states have been discarded
+                       verify(subtaskState, 
times(1)).discardSharedStatesOnFail();
+                       verify(subtaskState, times(1)).discardState();
+
                        // no confirm message must have been sent
                        verify(commitVertex.getCurrentExecutionAttempt(), 
times(0)).notifyCheckpointComplete(anyLong(), anyLong());
 
@@ -993,8 +1048,6 @@ public class CheckpointCoordinatorTest {
                        // of the vertices that need to be acknowledged.
                        // non of the messages should throw an exception
 
-                       CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, 0L);
-
                        // wrong job id
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(new JobID(), ackAttemptID1, checkpointId));
 
@@ -1058,19 +1111,22 @@ public class CheckpointCoordinatorTest {
 
                long checkpointId = pendingCheckpoint.getCheckpointId();
 
-               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, 0L);
-
                SubtaskState triggerSubtaskState = mock(SubtaskState.class);
 
                // acknowledge the first trigger vertex
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new 
CheckpointMetrics(), triggerSubtaskState));
 
+               // verify that the acknowledged subtask state has not been 
discarded
+               verify(triggerSubtaskState, 
never()).discardSharedStatesOnFail();
+               verify(triggerSubtaskState, never()).discardState();
+
                SubtaskState unknownSubtaskState = mock(SubtaskState.class);
 
                // receive an acknowledge message for an unknown vertex
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new 
CheckpointMetrics(), unknownSubtaskState));
 
                // we should discard acknowledge messages from an unknown 
vertex belonging to our job
+               verify(unknownSubtaskState, 
times(1)).discardSharedStatesOnFail();
                verify(unknownSubtaskState, times(1)).discardState();
 
                SubtaskState differentJobSubtaskState = 
mock(SubtaskState.class);
@@ -1079,20 +1135,25 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new 
JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), 
differentJobSubtaskState));
 
                // we should not interfere with different jobs
+               verify(differentJobSubtaskState, 
never()).discardSharedStatesOnFail();
                verify(differentJobSubtaskState, never()).discardState();
 
                // duplicate acknowledge message for the trigger vertex
+               reset(triggerSubtaskState);
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new 
CheckpointMetrics(), triggerSubtaskState));
 
                // duplicate acknowledge messages for a known vertex should not 
trigger discarding the state
+               verify(triggerSubtaskState, 
never()).discardSharedStatesOnFail();
                verify(triggerSubtaskState, never()).discardState();
 
                // let the checkpoint fail at the first ack vertex
+               reset(triggerSubtaskState);
                coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, 
ackAttemptId1, checkpointId));
 
                assertTrue(pendingCheckpoint.isDiscarded());
 
                // check that we've cleaned up the already acknowledged state
+               verify(triggerSubtaskState, 
times(1)).discardSharedStatesOnFail();
                verify(triggerSubtaskState, times(1)).discardState();
 
                SubtaskState ackSubtaskState = mock(SubtaskState.class);
@@ -1101,12 +1162,15 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new 
CheckpointMetrics(), ackSubtaskState));
 
                // check that we also cleaned up this state
+               verify(ackSubtaskState, times(1)).discardSharedStatesOnFail();
                verify(ackSubtaskState, times(1)).discardState();
 
                // receive an acknowledge message from an unknown job
+               reset(differentJobSubtaskState);
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new 
JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), 
differentJobSubtaskState));
 
                // we should not interfere with different jobs
+               verify(differentJobSubtaskState, 
never()).discardSharedStatesOnFail();
                verify(differentJobSubtaskState, never()).discardState();
 
                SubtaskState unknownSubtaskState2 = mock(SubtaskState.class);
@@ -1115,6 +1179,7 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new 
CheckpointMetrics(), unknownSubtaskState2));
 
                // we should discard acknowledge messages from an unknown 
vertex belonging to our job
+               verify(unknownSubtaskState2, 
times(1)).discardSharedStatesOnFail();
                verify(unknownSubtaskState2, times(1)).discardState();
        }
 
@@ -1363,12 +1428,11 @@ public class CheckpointCoordinatorTest {
                assertFalse(pending.isDiscarded());
                assertFalse(pending.isFullyAcknowledged());
                assertFalse(pending.canBeSubsumed());
-               assertTrue(pending instanceof PendingCheckpoint);
-
-               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, 0L);
 
                // acknowledge from one of the tasks
-               coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID2, checkpointId));
+               SubtaskState subtaskState2 = mock(SubtaskState.class);
+               AcknowledgeCheckpoint acknowledgeCheckpoint2 = new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), 
subtaskState2);
+               coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2);
                assertEquals(1, pending.getNumberOfAcknowledgedTasks());
                assertEquals(1, pending.getNumberOfNonAcknowledgedTasks());
                assertFalse(pending.isDiscarded());
@@ -1376,13 +1440,14 @@ public class CheckpointCoordinatorTest {
                assertFalse(savepointFuture.isDone());
 
                // acknowledge the same task again (should not matter)
-               coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID2, checkpointId));
+               coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2);
                assertFalse(pending.isDiscarded());
                assertFalse(pending.isFullyAcknowledged());
                assertFalse(savepointFuture.isDone());
 
                // acknowledge the other task.
-               coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID1, checkpointId));
+               SubtaskState subtaskState1 = mock(SubtaskState.class);
+               coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1));
 
                // the checkpoint is internally converted to a successful 
checkpoint and the
                // pending checkpoint object is disposed
@@ -1399,11 +1464,17 @@ public class CheckpointCoordinatorTest {
                        verify(vertex2.getCurrentExecutionAttempt(), 
times(1)).notifyCheckpointComplete(eq(checkpointId), eq(timestamp));
                }
 
+               // validate that the shared states are registered
+               {
+                       verify(subtaskState1, 
times(1)).registerSharedStates(any(SharedStateRegistry.class));
+                       verify(subtaskState2, 
times(1)).registerSharedStates(any(SharedStateRegistry.class));
+               }
+
                CompletedCheckpoint success = 
coord.getSuccessfulCheckpoints().get(0);
                assertEquals(jid, success.getJobId());
                assertEquals(timestamp, success.getTimestamp());
                assertEquals(pending.getCheckpointId(), 
success.getCheckpointID());
-               assertTrue(success.getTaskStates().isEmpty());
+               assertEquals(2, success.getTaskStates().size());
 
                // ---------------
                // trigger another checkpoint and see that this one replaces 
the other checkpoint
@@ -1426,6 +1497,14 @@ public class CheckpointCoordinatorTest {
                assertTrue(successNew.getTaskStates().isEmpty());
                assertTrue(savepointFuture.isDone());
 
+               // validate that the first savepoint does not discard its 
private states.
+               verify(subtaskState1, never()).discardState();
+               verify(subtaskState2, never()).discardState();
+
+               // Savepoints are not supposed to have any shared state.
+               verify(subtaskState1, 
never()).unregisterSharedStates(any(SharedStateRegistry.class));
+               verify(subtaskState2, 
never()).unregisterSharedStates(any(SharedStateRegistry.class));
+
                // validate that the relevant tasks got a confirmation message
                {
                        verify(vertex1.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), 
any(CheckpointOptions.class));
@@ -1478,7 +1557,6 @@ public class CheckpointCoordinatorTest {
                // Trigger savepoint and checkpoint
                Future<CompletedCheckpoint> savepointFuture1 = 
coord.triggerSavepoint(timestamp, savepointDir);
                long savepointId1 = counter.getLast();
-               CheckpointMetaData checkpointMetaDataS1 = new 
CheckpointMetaData(savepointId1, 0L);
                assertEquals(1, coord.getNumberOfPendingCheckpoints());
 
                assertTrue(coord.triggerCheckpoint(timestamp + 1, false));
@@ -1488,8 +1566,6 @@ public class CheckpointCoordinatorTest {
                long checkpointId2 = counter.getLast();
                assertEquals(3, coord.getNumberOfPendingCheckpoints());
 
-               CheckpointMetaData checkpointMetaData2 = new 
CheckpointMetaData(checkpointId2, 0L);
-
                // 2nd checkpoint should subsume the 1st checkpoint, but not 
the savepoint
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID1, checkpointId2));
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID2, checkpointId2));
@@ -1505,7 +1581,6 @@ public class CheckpointCoordinatorTest {
 
                Future<CompletedCheckpoint> savepointFuture2 = 
coord.triggerSavepoint(timestamp + 4, savepointDir);
                long savepointId2 = counter.getLast();
-               CheckpointMetaData checkpointMetaDataS2 = new 
CheckpointMetaData(savepointId2, 0L);
                assertEquals(3, coord.getNumberOfPendingCheckpoints());
 
                // 2nd savepoint should subsume the last checkpoint, but not 
the 1st savepoint
@@ -1880,6 +1955,8 @@ public class CheckpointCoordinatorTest {
                ExecutionVertex[] arrayExecutionVertices =
                                allExecutionVertices.toArray(new 
ExecutionVertex[allExecutionVertices.size()]);
 
+               CompletedCheckpointStore store = new 
RecoverableCompletedCheckpointStore();
+
                // set up the coordinator and validate the initial state
                CheckpointCoordinator coord = new CheckpointCoordinator(
                        jid,
@@ -1892,7 +1969,7 @@ public class CheckpointCoordinatorTest {
                        arrayExecutionVertices,
                        arrayExecutionVertices,
                        new StandaloneCheckpointIDCounter(),
-                       new StandaloneCompletedCheckpointStore(1),
+                       store,
                        null,
                        Executors.directExecutor());
 
@@ -1901,38 +1978,32 @@ public class CheckpointCoordinatorTest {
 
                assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
                long checkpointId = 
Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
-               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, 0L);
 
                List<KeyGroupRange> keyGroupPartitions1 = 
StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, 
parallelism1);
                List<KeyGroupRange> keyGroupPartitions2 = 
StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, 
parallelism2);
 
                for (int index = 0; index < jobVertex1.getParallelism(); 
index++) {
-                       ChainedStateHandle<StreamStateHandle> 
nonPartitionedState = generateStateForVertex(jobVertexID1, index);
-                       ChainedStateHandle<OperatorStateHandle> 
partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, 
index, 2, 8, false);
-                       KeyGroupsStateHandle partitionedKeyGroupState = 
generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
+                       SubtaskState subtaskState = 
mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
 
-                       SubtaskState checkpointStateHandles = new 
SubtaskState(nonPartitionedState, partitionableState, null, 
partitionedKeyGroupState, null);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
                                        new CheckpointMetrics(),
-                                       checkpointStateHandles);
+                                       subtaskState);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
 
                for (int index = 0; index < jobVertex2.getParallelism(); 
index++) {
-                       ChainedStateHandle<StreamStateHandle> 
nonPartitionedState = generateStateForVertex(jobVertexID2, index);
-                       ChainedStateHandle<OperatorStateHandle> 
partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, 
index, 2, 8, false);
-                       KeyGroupsStateHandle partitionedKeyGroupState = 
generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
-                       SubtaskState checkpointStateHandles = new 
SubtaskState(nonPartitionedState, partitionableState, null, 
partitionedKeyGroupState, null);
+                       SubtaskState subtaskState = 
mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
+
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
                                        new CheckpointMetrics(),
-                                       checkpointStateHandles);
+                                       subtaskState);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -1941,6 +2012,20 @@ public class CheckpointCoordinatorTest {
 
                assertEquals(1, completedCheckpoints.size());
 
+               // shutdown the store
+               SharedStateRegistry sharedStateRegistry = 
coord.getSharedStateRegistry();
+               store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
+
+               // All shared states should be unregistered once the store is shut down
+               for (CompletedCheckpoint completedCheckpoint : 
completedCheckpoints) {
+                       for (TaskState taskState : 
completedCheckpoint.getTaskStates().values()) {
+                               for (SubtaskState subtaskState : 
taskState.getStates()) {
+                                       verify(subtaskState, 
times(1)).unregisterSharedStates(sharedStateRegistry);
+                               }
+                       }
+               }
+
+               // restore the store
                Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
 
                tasks.put(jobVertexID1, jobVertex1);
@@ -1948,6 +2033,15 @@ public class CheckpointCoordinatorTest {
 
                coord.restoreLatestCheckpointedState(tasks, true, false);
 
+               // validate that all shared states are registered again after the recovery.
+               for (CompletedCheckpoint completedCheckpoint : 
completedCheckpoints) {
+                       for (TaskState taskState : 
completedCheckpoint.getTaskStates().values()) {
+                               for (SubtaskState subtaskState : 
taskState.getStates()) {
+                                       verify(subtaskState, 
times(2)).registerSharedStates(sharedStateRegistry);
+                               }
+                       }
+               }
+
                // verify the restored state
                verifyStateRestore(jobVertexID1, jobVertex1, 
keyGroupPartitions1);
                verifyStateRestore(jobVertexID2, jobVertex2, 
keyGroupPartitions2);
@@ -2666,6 +2760,26 @@ public class CheckpointCoordinatorTest {
                return vertex;
        }
 
+       static SubtaskState mockSubtaskState(
+               JobVertexID jobVertexID,
+               int index,
+               KeyGroupRange keyGroupRange) throws IOException {
+
+               ChainedStateHandle<StreamStateHandle> nonPartitionedState = 
generateStateForVertex(jobVertexID, index);
+               ChainedStateHandle<OperatorStateHandle> partitionableState = 
generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false);
+               KeyGroupsStateHandle partitionedKeyGroupState = 
generateKeyGroupState(jobVertexID, keyGroupRange, false);
+
+               SubtaskState subtaskState = mock(SubtaskState.class, 
withSettings().serializable());
+
+               
doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState();
+               
doReturn(partitionableState).when(subtaskState).getManagedOperatorState();
+               doReturn(null).when(subtaskState).getRawOperatorState();
+               
doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState();
+               doReturn(null).when(subtaskState).getRawKeyedState();
+
+               return subtaskState;
+       }
+
        public static void verifyStateRestore(
                        JobVertexID jobVertexID, ExecutionJobVertex 
executionJobVertex,
                        List<KeyGroupRange> keyGroupPartitions) throws 
Exception {
@@ -3018,7 +3132,6 @@ public class CheckpointCoordinatorTest {
                ExecutionVertex vertex1 = mockExecutionVertex(new 
ExecutionAttemptID());
 
                StandaloneCompletedCheckpointStore store = new 
StandaloneCompletedCheckpointStore(1);
-               store.addCheckpoint(new CompletedCheckpoint(new JobID(), 0, 0, 
0, Collections.<JobVertexID, TaskState>emptyMap()));
 
                // set up the coordinator and validate the initial state
                CheckpointCoordinator coord = new CheckpointCoordinator(
@@ -3036,6 +3149,10 @@ public class CheckpointCoordinatorTest {
                        null,
                        Executors.directExecutor());
 
+               store.addCheckpoint(
+                       new CompletedCheckpoint(new JobID(), 0, 0, 0, 
Collections.<JobVertexID, TaskState>emptyMap()),
+                       coord.getSharedStateRegistry());
+
                CheckpointStatsTracker tracker = 
mock(CheckpointStatsTracker.class);
                coord.setCheckpointStatsTracker(tracker);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 7e0a7c1..9e372e1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -255,7 +255,7 @@ public class CheckpointStateRestoreTest {
                }
                CompletedCheckpoint checkpoint = new CompletedCheckpoint(new 
JobID(), 0, 1, 2, new HashMap<>(checkpointTaskStates));
 
-               coord.getCheckpointStore().addCheckpoint(checkpoint);
+               coord.getCheckpointStore().addCheckpoint(checkpoint, 
coord.getSharedStateRegistry());
 
                coord.restoreLatestCheckpointedState(tasks, true, false);
                coord.restoreLatestCheckpointedState(tasks, true, true);
@@ -273,7 +273,7 @@ public class CheckpointStateRestoreTest {
 
                checkpoint = new CompletedCheckpoint(new JobID(), 1, 2, 3, new 
HashMap<>(checkpointTaskStates));
 
-               coord.getCheckpointStore().addCheckpoint(checkpoint);
+               coord.getCheckpointStore().addCheckpoint(checkpoint, 
coord.getSharedStateRegistry());
 
                // (i) Allow non restored state (should succeed)
                coord.restoreLatestCheckpointedState(tasks, true, true);

Reply via email to