Repository: flink
Updated Branches:
  refs/heads/master db31ca3f8 -> aa21f853a


[FLINK-6014] [checkpoint] Additional review changes


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/aa21f853
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/aa21f853
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/aa21f853

Branch: refs/heads/master
Commit: aa21f853ab0380ec1f68ae1d0b7c8d9268da4533
Parents: 218bed8
Author: Stefan Richter <[email protected]>
Authored: Sat Apr 22 01:23:09 2017 +0200
Committer: Stefan Richter <[email protected]>
Committed: Sat Apr 22 15:25:56 2017 +0200

----------------------------------------------------------------------
 .../AbstractCompletedCheckpointStore.java       |  37 ++++
 .../checkpoint/CheckpointCoordinator.java       |  41 ++---
 .../runtime/checkpoint/CompletedCheckpoint.java | 169 +++++++++++--------
 .../checkpoint/CompletedCheckpointStore.java    |  10 +-
 .../runtime/checkpoint/PendingCheckpoint.java   |  12 +-
 .../StandaloneCompletedCheckpointStore.java     |   9 +-
 .../flink/runtime/checkpoint/SubtaskState.java  |   5 -
 .../flink/runtime/checkpoint/TaskState.java     |   7 -
 .../ZooKeeperCompletedCheckpointStore.java      |   8 +-
 .../runtime/state/CompositeStateHandle.java     |  22 +--
 .../flink/runtime/state/SharedStateHandle.java  |   2 +-
 .../runtime/state/SharedStateRegistry.java      |  84 ++++-----
 .../apache/flink/runtime/state/StateObject.java |   6 +-
 .../CheckpointCoordinatorFailureTest.java       |  10 +-
 .../checkpoint/CheckpointCoordinatorTest.java   |  23 +--
 .../checkpoint/CheckpointStateRestoreTest.java  |   5 +-
 .../CompletedCheckpointStoreTest.java           |  28 ++-
 ...ExecutionGraphCheckpointCoordinatorTest.java |   7 +-
 .../checkpoint/PendingCheckpointTest.java       |   4 -
 .../StandaloneCompletedCheckpointStoreTest.java |  26 ++-
 ...ZooKeeperCompletedCheckpointStoreITCase.java |  41 ++---
 .../ZooKeeperCompletedCheckpointStoreTest.java  |  10 +-
 .../runtime/state/SharedStateRegistryTest.java  |  42 +----
 .../RecoverableCompletedCheckpointStore.java    |  13 +-
 24 files changed, 292 insertions(+), 329 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/AbstractCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/AbstractCompletedCheckpointStore.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/AbstractCompletedCheckpointStore.java
new file mode 100644
index 0000000..f42fd06
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/AbstractCompletedCheckpointStore.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.SharedStateRegistry;
+
+/**
+ * This is the base class that provides the implementation of some aspects common 
to all
+ * {@link CompletedCheckpointStore}s.
+ */
+public abstract class AbstractCompletedCheckpointStore implements 
CompletedCheckpointStore {
+
+       /**
+        * Registry for shared states.
+        */
+       protected final SharedStateRegistry sharedStateRegistry;
+
+       public AbstractCompletedCheckpointStore() {
+               this.sharedStateRegistry = new SharedStateRegistry();
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 5309dd4..256321e 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -37,9 +37,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.TaskStateHandles;
-
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
@@ -110,9 +108,6 @@ public class CheckpointCoordinator {
        /** Completed checkpoints. Implementations can be blocking. Make sure 
calls to methods
         * accessing this don't block the job manager actor and run 
asynchronously. */
        private final CompletedCheckpointStore completedCheckpointStore;
-       
-       /** Registry for shared states */
-       private final SharedStateRegistry sharedStateRegistry;
 
        /** Default directory for persistent checkpoints; <code>null</code> if 
none configured.
         * THIS WILL BE REPLACED BY PROPER STATE-BACKEND METADATA WRITING */
@@ -223,7 +218,6 @@ public class CheckpointCoordinator {
                this.completedCheckpointStore = 
checkNotNull(completedCheckpointStore);
                this.checkpointDirectory = checkpointDirectory;
                this.executor = checkNotNull(executor);
-               this.sharedStateRegistry = new SharedStateRegistry();
 
                this.recentPendingCheckpoints = new 
ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS);
 
@@ -288,7 +282,7 @@ public class CheckpointCoordinator {
                                }
                                pendingCheckpoints.clear();
 
-                               completedCheckpointStore.shutdown(jobStatus, 
sharedStateRegistry);
+                               completedCheckpointStore.shutdown(jobStatus);
                                checkpointIdCounter.shutdown(jobStatus);
                        }
                }
@@ -732,7 +726,7 @@ public class CheckpointCoordinator {
                                                                "the state 
handle to avoid lingering state.", message.getCheckpointId(),
                                                        
message.getTaskExecutionId(), message.getJob());
 
-                                               discardState(message.getJob(), 
message.getTaskExecutionId(), message.getCheckpointId(), 
message.getSubtaskState());
+                                               
discardSubtaskState(message.getJob(), message.getTaskExecutionId(), 
message.getCheckpointId(), message.getSubtaskState());
 
                                                break;
                                        case DISCARDED:
@@ -741,7 +735,7 @@ public class CheckpointCoordinator {
                                                                "state handle 
tp avoid lingering state.",
                                                        
message.getCheckpointId(), message.getTaskExecutionId(), message.getJob());
 
-                                               discardState(message.getJob(), 
message.getTaskExecutionId(), message.getCheckpointId(), 
message.getSubtaskState());
+                                               
discardSubtaskState(message.getJob(), message.getTaskExecutionId(), 
message.getCheckpointId(), message.getSubtaskState());
                                }
 
                                return true;
@@ -767,7 +761,7 @@ public class CheckpointCoordinator {
                                }
 
                                // try to discard the state so that we don't 
have lingering state lying around
-                               discardState(message.getJob(), 
message.getTaskExecutionId(), message.getCheckpointId(), 
message.getSubtaskState());
+                               discardSubtaskState(message.getJob(), 
message.getTaskExecutionId(), message.getCheckpointId(), 
message.getSubtaskState());
 
                                return wasPendingCheckpoint;
                        }
@@ -805,16 +799,16 @@ public class CheckpointCoordinator {
        
                        // the pending checkpoint must be discarded after the 
finalization
                        
Preconditions.checkState(pendingCheckpoint.isDiscarded() && completedCheckpoint 
!= null);
-       
+
                        try {
-                               
completedCheckpointStore.addCheckpoint(completedCheckpoint, 
sharedStateRegistry);
+                               
completedCheckpointStore.addCheckpoint(completedCheckpoint);
                        } catch (Exception exception) {
                                // we failed to store the completed checkpoint. 
Let's clean up
                                executor.execute(new Runnable() {
                                        @Override
                                        public void run() {
                                                try {
-                                                       
completedCheckpoint.discardOnFail();
+                                                       
completedCheckpoint.discardOnFailedStoring();
                                                } catch (Throwable t) {
                                                        LOG.warn("Could not 
properly discard completed checkpoint {}.", 
completedCheckpoint.getCheckpointID(), t);
                                                }
@@ -953,7 +947,7 @@ public class CheckpointCoordinator {
                        }
 
                        // Recover the checkpoints
-                       completedCheckpointStore.recover(sharedStateRegistry);
+                       completedCheckpointStore.recover();
 
                        // restore from the latest checkpoint
                        CompletedCheckpoint latest = 
completedCheckpointStore.getLatestCheckpoint();
@@ -1017,7 +1011,7 @@ public class CheckpointCoordinator {
                CompletedCheckpoint savepoint = 
SavepointLoader.loadAndValidateSavepoint(
                                job, tasks, savepointPath, userClassLoader, 
allowNonRestored);
 
-               completedCheckpointStore.addCheckpoint(savepoint, 
sharedStateRegistry);
+               completedCheckpointStore.addCheckpoint(savepoint);
                
                // Reset the checkpoint ID counter
                long nextCheckpointId = savepoint.getCheckpointID() + 1;
@@ -1057,10 +1051,11 @@ public class CheckpointCoordinator {
        public CompletedCheckpointStore getCheckpointStore() {
                return completedCheckpointStore;
        }
-       
-       public SharedStateRegistry getSharedStateRegistry() {
-               return sharedStateRegistry;
-       }
+
+//     @VisibleForTesting
+//     SharedStateRegistry getSharedStateRegistry() {
+//             return sharedStateRegistry;
+//     }
 
        public CheckpointIDCounter getCheckpointIdCounter() {
                return checkpointIdCounter;
@@ -1151,7 +1146,7 @@ public class CheckpointCoordinator {
         * @param checkpointId of the state object
         * @param subtaskState to discard asynchronously
         */
-       private void discardState(
+       private void discardSubtaskState(
                        final JobID jobId,
                        final ExecutionAttemptID executionAttemptID,
                        final long checkpointId,
@@ -1161,12 +1156,6 @@ public class CheckpointCoordinator {
                        executor.execute(new Runnable() {
                                @Override
                                public void run() {
-                                       try {
-                                               
subtaskState.discardSharedStatesOnFail();
-                                       } catch (Throwable t1) {
-                                               LOG.warn("Could not properly 
discard shared states of checkpoint {} " +
-                                                       "belonging to task {} 
of job {}.", checkpointId, executionAttemptID, jobId, t1);
-                                       }
 
                                        try {
                                                subtaskState.discardState();

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index 58e91e1..79fc31f 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -20,14 +20,12 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
-import 
org.apache.flink.runtime.checkpoint.CompletedCheckpointStats.DiscardCallback;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.ExceptionUtils;
-
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -107,7 +105,7 @@ public class CompletedCheckpoint implements Serializable {
 
        /** Optional stats tracker callback for discard. */
        @Nullable
-       private transient volatile DiscardCallback discardCallback;
+       private transient volatile CompletedCheckpointStats.DiscardCallback 
discardCallback;
 
        // 
------------------------------------------------------------------------
 
@@ -151,7 +149,7 @@ public class CompletedCheckpoint implements Serializable {
                checkArgument((externalPointer == null) == 
(externalizedMetadata == null),
                                "external pointer without externalized metadata 
must be both null or both non-null");
 
-               checkArgument(!props.externalizeCheckpoint() || externalPointer 
!= null, 
+               checkArgument(!props.externalizeCheckpoint() || externalPointer 
!= null,
                        "Checkpoint properties require externalized checkpoint, 
but checkpoint is not externalized");
 
                this.job = checkNotNull(job);
@@ -186,15 +184,14 @@ public class CompletedCheckpoint implements Serializable {
                return props;
        }
 
-       public void discardOnFail() throws Exception {
-               discard(null, true);
+       public void discardOnFailedStoring() throws Exception {
+               new UnstoredDiscardStategy().discard();
        }
 
        public boolean discardOnSubsume(SharedStateRegistry 
sharedStateRegistry) throws Exception {
-               Preconditions.checkNotNull(sharedStateRegistry, "The registry 
cannot be null.");
 
                if (props.discardOnSubsumed()) {
-                       discard(sharedStateRegistry, false);
+                       new 
StoredDiscardStrategy(sharedStateRegistry).discard();
                        return true;
                }
 
@@ -202,14 +199,13 @@ public class CompletedCheckpoint implements Serializable {
        }
 
        public boolean discardOnShutdown(JobStatus jobStatus, 
SharedStateRegistry sharedStateRegistry) throws Exception {
-               Preconditions.checkNotNull(sharedStateRegistry, "The registry 
cannot be null.");
 
                if (jobStatus == JobStatus.FINISHED && 
props.discardOnJobFinished() ||
                                jobStatus == JobStatus.CANCELED && 
props.discardOnJobCancelled() ||
                                jobStatus == JobStatus.FAILED && 
props.discardOnJobFailed() ||
                                jobStatus == JobStatus.SUSPENDED && 
props.discardOnJobSuspended()) {
 
-                       discard(sharedStateRegistry, false);
+                       new 
StoredDiscardStrategy(sharedStateRegistry).discard();
                        return true;
                } else {
                        if (externalPointer != null) {
@@ -221,53 +217,6 @@ public class CompletedCheckpoint implements Serializable {
                }
        }
 
-       private void discard(SharedStateRegistry sharedStateRegistry, boolean 
failed) throws Exception {
-               Preconditions.checkState(failed || (sharedStateRegistry != 
null),
-                       "The registry must not be null if the complete 
checkpoint does not fail.");
-
-               try {
-                       // collect exceptions and continue cleanup
-                       Exception exception = null;
-
-                       // drop the metadata, if we have some
-                       if (externalizedMetadata != null) {
-                               try {
-                                       externalizedMetadata.discardState();
-                               } catch (Exception e) {
-                                       exception = e;
-                               }
-                       }
-
-                       // In the cases where the completed checkpoint fails, 
the shared
-                       // states have not been registered to the registry. 
It's the state
-                       // handles' responsibility to discard their shared 
states.
-                       if (!failed) {
-                               unregisterSharedStates(sharedStateRegistry);
-                       } else {
-                               discardSharedStatesOnFail();
-                       }
-
-                       // discard private state objects
-                       try {
-                               
StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
-                       } catch (Exception e) {
-                               exception = ExceptionUtils.firstOrSuppressed(e, 
exception);
-                       }
-
-                       if (exception != null) {
-                               throw exception;
-                       }
-               } finally {
-                       taskStates.clear();
-
-                       // to be null-pointer safe, copy reference to stack
-                       DiscardCallback discardCallback = this.discardCallback;
-                       if (discardCallback != null) {
-                               discardCallback.notifyDiscardedCheckpoint();
-                       }
-               }
-       }
-
        public long getStateSize() {
                long result = 0L;
 
@@ -319,30 +268,108 @@ public class CompletedCheckpoint implements Serializable 
{
                sharedStateRegistry.registerAll(taskStates.values());
        }
 
+       // 
--------------------------------------------------------------------------------------------
+
+       @Override
+       public String toString() {
+               return String.format("Checkpoint %d @ %d for %s", checkpointID, 
timestamp, job);
+       }
+
        /**
-        * Unregister all shared states from the given registry. This is method 
is
-        * called when the completed checkpoint is subsumed or the job 
terminates.
-        *
-        * @param sharedStateRegistry The registry where shared states are 
registered
+        * Base class for the discarding strategies of {@link 
CompletedCheckpoint}.
         */
-       private void unregisterSharedStates(SharedStateRegistry 
sharedStateRegistry) {
-               sharedStateRegistry.unregisterAll(taskStates.values());
+       private abstract class DiscardStrategy {
+
+               protected Exception storedException;
+
+               public DiscardStrategy() {
+                       this.storedException = null;
+               }
+
+               public void discard() throws Exception {
+
+                       try {
+                               // collect exceptions and continue cleanup
+                               storedException = null;
+
+                               doDiscardExternalizedMetaData();
+                               doDiscardSharedState();
+                               doDiscardPrivateState();
+                               doReportStoredExceptions();
+                       } finally {
+                               clearTaskStatesAndNotifyDiscardCompleted();
+                       }
+               }
+
+               protected void doDiscardExternalizedMetaData() {
+                       // drop the metadata, if we have some
+                       if (externalizedMetadata != null) {
+                               try {
+                                       externalizedMetadata.discardState();
+                               } catch (Exception e) {
+                                       storedException = e;
+                               }
+                       }
+               }
+
+               protected void doDiscardPrivateState() {
+                       // discard private state objects
+                       try {
+                               
StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
+                       } catch (Exception e) {
+                               storedException = 
ExceptionUtils.firstOrSuppressed(e, storedException);
+                       }
+               }
+
+               protected abstract void doDiscardSharedState();
+
+               protected void doReportStoredExceptions() throws Exception {
+                       if (storedException != null) {
+                               throw storedException;
+                       }
+               }
+
+               protected void clearTaskStatesAndNotifyDiscardCompleted() {
+                       taskStates.clear();
+                       // to be null-pointer safe, copy reference to stack
+                       CompletedCheckpointStats.DiscardCallback 
discardCallback =
+                               CompletedCheckpoint.this.discardCallback;
+
+                       if (discardCallback != null) {
+                               discardCallback.notifyDiscardedCheckpoint();
+                       }
+               }
        }
 
        /**
-        * Discard all shared states created in the checkpoint. This method is 
called
+        * Discard all shared states created in the checkpoint. This strategy 
is applied
         * when the completed checkpoint fails to be added into the store.
         */
-       private void discardSharedStatesOnFail() throws Exception {
-               for (TaskState taskState : taskStates.values()) {
-                       taskState.discardSharedStatesOnFail();
+       private class UnstoredDiscardStategy extends 
CompletedCheckpoint.DiscardStrategy {
+
+               @Override
+               protected void doDiscardSharedState() {
+                       // Nothing to do because we did not register any shared 
state yet. Unregistered, new
+                       // shared state is then still considered private state 
and deleted as part of
+                       // doDiscardPrivateState().
                }
        }
 
-       // 
--------------------------------------------------------------------------------------------
+       /**
+        * Unregister all shared states from the given registry. This 
strategy is
+        * applied when the completed checkpoint is subsumed or the job 
terminates.
+        */
+       private class StoredDiscardStrategy extends 
CompletedCheckpoint.DiscardStrategy {
 
-       @Override
-       public String toString() {
-               return String.format("Checkpoint %d @ %d for %s", checkpointID, 
timestamp, job);
+               SharedStateRegistry sharedStateRegistry;
+
+               public StoredDiscardStrategy(SharedStateRegistry 
sharedStateRegistry) {
+                       this.sharedStateRegistry = 
Preconditions.checkNotNull(sharedStateRegistry);
+               }
+
+               @Override
+               protected void doDiscardSharedState() {
+                       sharedStateRegistry.unregisterAll(taskStates.values());
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
index 0ade25c..82193b5 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
@@ -19,7 +19,6 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 
 import java.util.List;
 
@@ -34,16 +33,15 @@ public interface CompletedCheckpointStore {
         * <p>After a call to this method, {@link #getLatestCheckpoint()} 
returns the latest
         * available checkpoint.
         */
-       void recover(SharedStateRegistry sharedStateRegistry) throws Exception;
+       void recover() throws Exception;
 
        /**
         * Adds a {@link CompletedCheckpoint} instance to the list of completed 
checkpoints.
         *
         * <p>Only a bounded number of checkpoints is kept. When exceeding the 
maximum number of
-        * retained checkpoints, the oldest one will be discarded via {@link
-        * CompletedCheckpoint#discardOnSubsume(SharedStateRegistry)} )}.
+        * retained checkpoints, the oldest one will be discarded.
         */
-       void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry 
sharedStateRegistry) throws Exception;
+       void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception;
 
        /**
         * Returns the latest {@link CompletedCheckpoint} instance or 
<code>null</code> if none was
@@ -59,7 +57,7 @@ public interface CompletedCheckpointStore {
         *
         * @param jobStatus Job state on shut down
         */
-       void shutdown(JobStatus jobStatus, SharedStateRegistry 
sharedStateRegistry) throws Exception;
+       void shutdown(JobStatus jobStatus) throws Exception;
 
        /**
         * Returns all {@link CompletedCheckpoint} instances.

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index 6805dea..900331b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -506,16 +506,8 @@ public class PendingCheckpoint {
                                                @Override
                                                public void run() {
 
-                                                       // discard the shared 
states that are created in the checkpoint
-                                                       for (TaskState 
taskState : taskStates.values()) {
-                                                               try {
-                                                                       
taskState.discardSharedStatesOnFail();
-                                                               } catch 
(Throwable t) {
-                                                                       
LOG.warn("Could not properly dispose unreferenced shared states.");
-                                                               }
-                                                       }
-
-                                                       // discard the private 
states
+                                                       // discard the private 
states.
+                                                       // unregistered shared 
states are still considered private at this point.
                                                        try {
                                                                
StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
                                                        } catch (Throwable t) {

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
index 9f833c3..f5e1db3 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
@@ -20,7 +20,6 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobmanager.HighAvailabilityMode;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -33,7 +32,7 @@ import static 
org.apache.flink.util.Preconditions.checkArgument;
 /**
  * {@link CompletedCheckpointStore} for JobManagers running in {@link 
HighAvailabilityMode#NONE}.
  */
-public class StandaloneCompletedCheckpointStore implements 
CompletedCheckpointStore {
+public class StandaloneCompletedCheckpointStore extends 
AbstractCompletedCheckpointStore {
 
        private static final Logger LOG = 
LoggerFactory.getLogger(StandaloneCompletedCheckpointStore.class);
 
@@ -57,12 +56,12 @@ public class StandaloneCompletedCheckpointStore implements 
CompletedCheckpointSt
        }
 
        @Override
-       public void recover(SharedStateRegistry sharedStateRegistry) throws 
Exception {
+       public void recover() throws Exception {
                // Nothing to do
        }
 
        @Override
-       public void addCheckpoint(CompletedCheckpoint checkpoint, 
SharedStateRegistry sharedStateRegistry) throws Exception {
+       public void addCheckpoint(CompletedCheckpoint checkpoint) throws 
Exception {
                
                checkpoints.addLast(checkpoint);
 
@@ -99,7 +98,7 @@ public class StandaloneCompletedCheckpointStore implements 
CompletedCheckpointSt
        }
 
        @Override
-       public void shutdown(JobStatus jobStatus, SharedStateRegistry 
sharedStateRegistry) throws Exception {
+       public void shutdown(JobStatus jobStatus) throws Exception {
                try {
                        LOG.info("Shutting down");
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index e968643..121ac57 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -161,11 +161,6 @@ public class SubtaskState implements CompositeStateHandle {
        }
 
        @Override
-       public void discardSharedStatesOnFail() {
-               // No shared states
-       }
-
-       @Override
        public long getStateSize() {
                return stateSize;
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
index 19fe962..4f5f536 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
@@ -144,13 +144,6 @@ public class TaskState implements CompositeStateHandle {
        }
 
        @Override
-       public void discardSharedStatesOnFail() {
-               for (SubtaskState subtaskState : subtaskStates.values()) {
-                       subtaskState.discardSharedStatesOnFail();
-               }
-       }
-
-       @Override
        public long getStateSize() {
                long result = 0L;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
index 07546ea..52a4eea 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
@@ -68,7 +68,7 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  * checkpoints is consistent. Currently, after recovery we start out with only 
a single
  * checkpoint to circumvent those situations.
  */
-public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointStore {
+public class ZooKeeperCompletedCheckpointStore extends 
AbstractCompletedCheckpointStore {
 
        private static final Logger LOG = 
LoggerFactory.getLogger(ZooKeeperCompletedCheckpointStore.class);
 
@@ -141,7 +141,7 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
         * that the history of checkpoints is consistent.
         */
        @Override
-       public void recover(SharedStateRegistry sharedStateRegistry) throws 
Exception {
+       public void recover() throws Exception {
                LOG.info("Recovering checkpoints from ZooKeeper.");
 
                // Clear local handles in order to prevent duplicates on
@@ -192,7 +192,7 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
         * @param checkpoint Completed checkpoint to add.
         */
        @Override
-       public void addCheckpoint(final CompletedCheckpoint checkpoint, final 
SharedStateRegistry sharedStateRegistry) throws Exception {
+       public void addCheckpoint(final CompletedCheckpoint checkpoint) throws 
Exception {
                checkNotNull(checkpoint, "Checkpoint");
                
                final String path = 
checkpointIdToPath(checkpoint.getCheckpointID());
@@ -281,7 +281,7 @@ public class ZooKeeperCompletedCheckpointStore implements 
CompletedCheckpointSto
        }
 
        @Override
-       public void shutdown(JobStatus jobStatus, SharedStateRegistry 
sharedStateRegistry) throws Exception {
+       public void shutdown(JobStatus jobStatus) throws Exception {
                if (jobStatus.isGloballyTerminalState()) {
                        LOG.info("Shutting down");
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
index 2ea5bc9..002b7c3 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
@@ -28,16 +28,24 @@ package org.apache.flink.runtime.state;
  * received by the {@link 
org.apache.flink.runtime.checkpoint.CheckpointCoordinator}
  * and will be discarded when the checkpoint is discarded.
  * 
- * <p>The {@link SharedStateRegistry} is responsible for the discarding of the
- * shared states. The composite state handle should only delete those private
- * states in the {@link StateObject#discardState()} method.
+ * <p>The {@link SharedStateRegistry} is responsible for the discarding of 
registered
+ * shared states. Before their first registration through
+ * {@link #registerSharedStates(SharedStateRegistry)}, newly created shared 
state is still owned by
+ * this handle and considered as private state until it is registered for the 
first time. Registration
+ * transfers ownership to the {@link SharedStateRegistry}.
+ * The composite state handle should only delete all private states in the
+ * {@link StateObject#discardState()} method.
  */
 public interface CompositeStateHandle extends StateObject {
 
        /**
-        * Register both created and referenced shared states in the given
+        * Register both newly created and already referenced shared states in 
the given
         * {@link SharedStateRegistry}. This method is called when the 
checkpoint
         * successfully completes or is recovered from failures.
+        * <p>
+        * After this is completed, newly created shared state is considered as 
published and is no longer
+        * owned by this handle. This means that it should no longer be deleted 
as part of calls to
+        * {@link #discardState()}.
         *
         * @param stateRegistry The registry where shared states are registered.
         */
@@ -51,10 +59,4 @@ public interface CompositeStateHandle extends StateObject {
         * @param stateRegistry The registry where shared states are registered.
         */
        void unregisterSharedStates(SharedStateRegistry stateRegistry);
-
-       /**
-        * Discard all shared states created in this checkpoint. This method is
-        * called when the checkpoint fails to complete.
-        */
-       void discardSharedStatesOnFail() throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
index f856052..c8c4046 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
@@ -35,5 +35,5 @@ public interface SharedStateHandle extends StateObject {
        /**
         * Return the identifier of the shared state.
         */
-       String getKey();
+       String getRegistrationKey();
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
index b5048d0..2cb43ac 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
@@ -18,13 +18,10 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.Serializable;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -33,73 +30,78 @@ import java.util.Map;
  * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} to 
  * maintain the reference count of {@link SharedStateHandle}s which are shared
  * among different checkpoints.
+ *
  */
-public class SharedStateRegistry implements Serializable {
+public class SharedStateRegistry {
 
-       private static Logger LOG = 
LoggerFactory.getLogger(SharedStateRegistry.class);
+       private static final Logger LOG = 
LoggerFactory.getLogger(SharedStateRegistry.class);
 
-       private static final long serialVersionUID = -8357254413007773970L;
+       /** All registered state objects by an artificial key */
+       private final Map<String, SharedStateRegistry.SharedStateEntry> 
registeredStates;
 
-       /** All registered state objects */
-       private final Map<String, SharedStateEntry> registeredStates = new 
HashMap<>();
+       public SharedStateRegistry() {
+               this.registeredStates = new HashMap<>();
+       }
 
        /**
-        * Register the state in the registry
+        * Register a reference to the given shared state in the registry. This 
increases the reference
+        * count for this shared state by one. Returns the reference count 
after the update.
         *
-        * @param state The state to register
-        * @param isNew True if the shared state is newly created
+        * @param state the shared state for which we register a reference.
+        * @return the updated reference count for the given shared state.
         */
-       public void register(SharedStateHandle state, boolean isNew) {
+       public int register(SharedStateHandle state) {
                if (state == null) {
-                       return;
+                       return 0;
                }
 
                synchronized (registeredStates) {
-                       SharedStateEntry entry = 
registeredStates.get(state.getKey());
-
-                       if (isNew) {
-                               Preconditions.checkState(entry == null,
-                                       "The state cannot be created more than 
once.");
+                       SharedStateRegistry.SharedStateEntry entry =
+                               
registeredStates.get(state.getRegistrationKey());
 
-                               registeredStates.put(state.getKey(), new 
SharedStateEntry(state));
+                       if (entry == null) {
+                               SharedStateRegistry.SharedStateEntry stateEntry 
=
+                                       new 
SharedStateRegistry.SharedStateEntry(state);
+                               
registeredStates.put(state.getRegistrationKey(), stateEntry);
+                               return 1;
                        } else {
-                               Preconditions.checkState(entry != null,
-                                       "The state cannot be referenced if it 
has not been created yet.");
-
                                entry.increaseReferenceCount();
+                               return entry.getReferenceCount();
                        }
                }
        }
 
        /**
-        * Unregister the state in the registry
+        * Unregister one reference to the given shared state in the registry. 
This decreases the
+        * reference count by one. Once the count reaches zero, the shared 
state is deleted.
         *
-        * @param state The state to unregister
+        * @param state the shared state for which we unregister a reference.
+        * @return the reference count for the shared state after the update.
         */
-       public void unregister(SharedStateHandle state) {
+       public int unregister(SharedStateHandle state) {
                if (state == null) {
-                       return;
+                       return 0;
                }
 
                synchronized (registeredStates) {
-                       SharedStateEntry entry = 
registeredStates.get(state.getKey());
+                       SharedStateRegistry.SharedStateEntry entry = 
registeredStates.get(state.getRegistrationKey());
 
-                       if (entry == null) {
-                               throw new IllegalStateException("Cannot 
unregister an unexisted state.");
-                       }
+                       Preconditions.checkState(entry != null, "Cannot 
unregister a state that is not registered.");
 
                        entry.decreaseReferenceCount();
 
-                       // Remove the state from the registry when it's not 
referenced any more.
-                       if (entry.getReferenceCount() == 0) {
-                               registeredStates.remove(state.getKey());
+                       final int newReferenceCount = entry.getReferenceCount();
 
+                       // Remove the state from the registry when it's not 
referenced any more.
+                       if (newReferenceCount <= 0) {
+                               
registeredStates.remove(state.getRegistrationKey());
                                try {
                                        entry.getState().discardState();
                                } catch (Exception e) {
-                                       LOG.warn("Cannot properly discard the 
state " + entry.getState() + ".", e);
+                                       LOG.warn("Cannot properly discard the 
state {}.", entry.getState(), e);
                                }
                        }
+                       return newReferenceCount;
                }
        }
 
@@ -108,7 +110,7 @@ public class SharedStateRegistry implements Serializable {
         *
         * @param stateHandles The shared states to register.
         */
-       public void registerAll(Collection<? extends CompositeStateHandle> 
stateHandles) {
+       public void registerAll(Iterable<? extends CompositeStateHandle> 
stateHandles) {
                if (stateHandles == null) {
                        return;
                }
@@ -127,7 +129,7 @@ public class SharedStateRegistry implements Serializable {
         *
         * @param stateHandles The shared states to unregister.
         */
-       public void unregisterAll(Collection<? extends CompositeStateHandle> 
stateHandles) {
+       public void unregisterAll(Iterable<? extends CompositeStateHandle> 
stateHandles) {
                if (stateHandles == null) {
                        return;
                }
@@ -140,6 +142,7 @@ public class SharedStateRegistry implements Serializable {
        }
 
        private static class SharedStateEntry {
+
                /** The shared object */
                private final SharedStateHandle state;
 
@@ -168,10 +171,13 @@ public class SharedStateRegistry implements Serializable {
                }
        }
 
-
-       @VisibleForTesting
        public int getReferenceCount(SharedStateHandle state) {
-               SharedStateEntry entry = registeredStates.get(state.getKey());
+               if (state == null) {
+                       return 0;
+               }
+
+               SharedStateRegistry.SharedStateEntry entry =
+                       registeredStates.get(state.getRegistrationKey());
 
                return entry == null ? 0 : entry.getReferenceCount();
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
index 7f1dd18..3b49df7 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.runtime.state;
 
+import java.io.Serializable;
+
 /**
  * Base of all handles that represent checkpointed state in some form. The 
object may hold
  * the (small) state directly, or contain a file path (state is in the file), 
or contain the
@@ -33,10 +35,10 @@ package org.apache.flink.runtime.state;
  * compatibility, they are not stored via {@link java.io.Serializable Java 
Serialization},
  * but through custom serializers.
  */
-public interface StateObject extends java.io.Serializable {
+public interface StateObject extends Serializable {
 
        /**
-        * Discards the state referred to by this handle, to free up resources 
in
+        * Discards the state referred to and solely owned by this handle, to 
free up resources in
         * the persistent storage. This method is called when the state 
represented by this
         * object will not be used any more.
         */

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 632f2c0..90b7fe7 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -25,7 +25,6 @@ import 
org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.util.TestLogger;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -39,9 +38,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
 @RunWith(PowerMockRunner.class)
@@ -105,19 +102,18 @@ public class CheckpointCoordinatorFailureTest extends 
TestLogger {
                assertTrue(pendingCheckpoint.isDiscarded());
 
                // make sure that the subtask state has been discarded after we 
could not complete it.
-               verify(subtaskState, times(1)).discardSharedStatesOnFail();
                verify(subtaskState).discardState();
        }
 
        private static final class FailingCompletedCheckpointStore implements 
CompletedCheckpointStore {
 
                @Override
-               public void recover(SharedStateRegistry sharedStateRegistry) 
throws Exception {
+               public void recover() throws Exception {
                        throw new UnsupportedOperationException("Not 
implemented.");
                }
 
                @Override
-               public void addCheckpoint(CompletedCheckpoint checkpoint, 
SharedStateRegistry sharedStateRegistry) throws Exception {
+               public void addCheckpoint(CompletedCheckpoint checkpoint) 
throws Exception {
                        throw new Exception("The failing completed checkpoint 
store failed again... :-(");
                }
 
@@ -127,7 +123,7 @@ public class CheckpointCoordinatorFailureTest extends 
TestLogger {
                }
 
                @Override
-               public void shutdown(JobStatus jobStatus, SharedStateRegistry 
sharedStateRegistry) throws Exception {
+               public void shutdown(JobStatus jobStatus) throws Exception {
                        throw new UnsupportedOperationException("Not 
implemented.");
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index fabf3fc..24169f2 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -88,7 +88,6 @@ import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
@@ -880,8 +879,6 @@ public class CheckpointCoordinatorTest {
                        assertEquals(1, 
coord.getNumberOfRetainedSuccessfulCheckpoints());
 
                        // validate that all received subtask states in the 
first checkpoint have been discarded
-                       verify(subtaskState1_1, 
times(1)).discardSharedStatesOnFail();
-                       verify(subtaskState1_2, 
times(1)).discardSharedStatesOnFail();
                        verify(subtaskState1_1, times(1)).discardState();
                        verify(subtaskState1_2, times(1)).discardState();
 
@@ -907,7 +904,6 @@ public class CheckpointCoordinatorTest {
                        // send the last remaining ack for the first 
checkpoint. This should not do anything
                        SubtaskState subtaskState1_3 = mock(SubtaskState.class);
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new 
CheckpointMetrics(), subtaskState1_3));
-                       verify(subtaskState1_3, 
times(1)).discardSharedStatesOnFail();
                        verify(subtaskState1_3, times(1)).discardState();
 
                        coord.shutdown(JobStatus.FINISHED);
@@ -993,7 +989,6 @@ public class CheckpointCoordinatorTest {
                        assertEquals(0, 
coord.getNumberOfRetainedSuccessfulCheckpoints());
 
                        // validate that the received states have been discarded
-                       verify(subtaskState, 
times(1)).discardSharedStatesOnFail();
                        verify(subtaskState, times(1)).discardState();
 
                        // no confirm message must have been sent
@@ -1117,7 +1112,6 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new 
CheckpointMetrics(), triggerSubtaskState));
 
                // verify that the subtask state has registered its shared 
states at the registry
-               verify(triggerSubtaskState, 
never()).discardSharedStatesOnFail();
                verify(triggerSubtaskState, never()).discardState();
 
                SubtaskState unknownSubtaskState = mock(SubtaskState.class);
@@ -1126,7 +1120,6 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new 
CheckpointMetrics(), unknownSubtaskState));
 
                // we should discard acknowledge messages from an unknown 
vertex belonging to our job
-               verify(unknownSubtaskState, 
times(1)).discardSharedStatesOnFail();
                verify(unknownSubtaskState, times(1)).discardState();
 
                SubtaskState differentJobSubtaskState = 
mock(SubtaskState.class);
@@ -1135,7 +1128,6 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new 
JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), 
differentJobSubtaskState));
 
                // we should not interfere with different jobs
-               verify(differentJobSubtaskState, 
never()).discardSharedStatesOnFail();
                verify(differentJobSubtaskState, never()).discardState();
 
                // duplicate acknowledge message for the trigger vertex
@@ -1143,7 +1135,6 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new 
CheckpointMetrics(), triggerSubtaskState));
 
                // duplicate acknowledge messages for a known vertex should not 
trigger discarding the state
-               verify(triggerSubtaskState, 
never()).discardSharedStatesOnFail();
                verify(triggerSubtaskState, never()).discardState();
 
                // let the checkpoint fail at the first ack vertex
@@ -1153,7 +1144,6 @@ public class CheckpointCoordinatorTest {
                assertTrue(pendingCheckpoint.isDiscarded());
 
                // check that we've cleaned up the already acknowledged state
-               verify(triggerSubtaskState, 
times(1)).discardSharedStatesOnFail();
                verify(triggerSubtaskState, times(1)).discardState();
 
                SubtaskState ackSubtaskState = mock(SubtaskState.class);
@@ -1162,7 +1152,6 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new 
CheckpointMetrics(), ackSubtaskState));
 
                // check that we also cleaned up this state
-               verify(ackSubtaskState, times(1)).discardSharedStatesOnFail();
                verify(ackSubtaskState, times(1)).discardState();
 
                // receive an acknowledge message from an unknown job
@@ -1170,7 +1159,6 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new 
JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), 
differentJobSubtaskState));
 
                // we should not interfere with different jobs
-               verify(differentJobSubtaskState, 
never()).discardSharedStatesOnFail();
                verify(differentJobSubtaskState, never()).discardState();
 
                SubtaskState unknownSubtaskState2 = mock(SubtaskState.class);
@@ -1179,7 +1167,6 @@ public class CheckpointCoordinatorTest {
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new 
CheckpointMetrics(), unknownSubtaskState2));
 
                // we should discard acknowledge messages from an unknown 
vertex belonging to our job
-               verify(unknownSubtaskState2, 
times(1)).discardSharedStatesOnFail();
                verify(unknownSubtaskState2, times(1)).discardState();
        }
 
@@ -2013,14 +2000,13 @@ public class CheckpointCoordinatorTest {
                assertEquals(1, completedCheckpoints.size());
 
                // shutdown the store
-               SharedStateRegistry sharedStateRegistry = 
coord.getSharedStateRegistry();
-               store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
+               store.shutdown(JobStatus.SUSPENDED);
 
                // All shared states should be unregistered once the store is 
shut down
                for (CompletedCheckpoint completedCheckpoint : 
completedCheckpoints) {
                        for (TaskState taskState : 
completedCheckpoint.getTaskStates().values()) {
                                for (SubtaskState subtaskState : 
taskState.getStates()) {
-                                       verify(subtaskState, 
times(1)).unregisterSharedStates(sharedStateRegistry);
+                                       verify(subtaskState, 
times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
                                }
                        }
                }
@@ -2037,7 +2023,7 @@ public class CheckpointCoordinatorTest {
                for (CompletedCheckpoint completedCheckpoint : 
completedCheckpoints) {
                        for (TaskState taskState : 
completedCheckpoint.getTaskStates().values()) {
                                for (SubtaskState subtaskState : 
taskState.getStates()) {
-                                       verify(subtaskState, 
times(2)).registerSharedStates(sharedStateRegistry);
+                                       verify(subtaskState, 
times(2)).registerSharedStates(any(SharedStateRegistry.class));
                                }
                        }
                }
@@ -3150,8 +3136,7 @@ public class CheckpointCoordinatorTest {
                        Executors.directExecutor());
 
                store.addCheckpoint(
-                       new CompletedCheckpoint(new JobID(), 0, 0, 0, 
Collections.<JobVertexID, TaskState>emptyMap()),
-                       coord.getSharedStateRegistry());
+                       new CompletedCheckpoint(new JobID(), 0, 0, 0, 
Collections.<JobVertexID, TaskState>emptyMap()));
 
                CheckpointStatsTracker tracker = 
mock(CheckpointStatsTracker.class);
                coord.setCheckpointStatsTracker(tracker);

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 9e372e1..2fc1de5 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -30,7 +30,6 @@ import 
org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -255,7 +254,7 @@ public class CheckpointStateRestoreTest {
                }
                CompletedCheckpoint checkpoint = new CompletedCheckpoint(new 
JobID(), 0, 1, 2, new HashMap<>(checkpointTaskStates));
 
-               coord.getCheckpointStore().addCheckpoint(checkpoint, 
coord.getSharedStateRegistry());
+               coord.getCheckpointStore().addCheckpoint(checkpoint);
 
                coord.restoreLatestCheckpointedState(tasks, true, false);
                coord.restoreLatestCheckpointedState(tasks, true, true);
@@ -273,7 +272,7 @@ public class CheckpointStateRestoreTest {
 
                checkpoint = new CompletedCheckpoint(new JobID(), 1, 2, 3, new 
HashMap<>(checkpointTaskStates));
 
-               coord.getCheckpointStore().addCheckpoint(checkpoint, 
coord.getSharedStateRegistry());
+               coord.getCheckpointStore().addCheckpoint(checkpoint);
 
                // (i) Allow non restored state (should succeed)
                coord.restoreLatestCheckpointedState(tasks, true, true);

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index aa1726b..4a36dd2 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -37,6 +37,7 @@ import java.util.concurrent.CountDownLatch;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -68,7 +69,6 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
        @Test
        public void testAddAndGetLatestCheckpoint() throws Exception {
                CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(4);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                
                // Empty state
                assertEquals(0, checkpoints.getNumberOfRetainedCheckpoints());
@@ -78,11 +78,11 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
                                createCheckpoint(0), createCheckpoint(1) };
 
                // Add and get latest
-               checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
+               checkpoints.addCheckpoint(expected[0]);
                assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
                verifyCheckpoint(expected[0], 
checkpoints.getLatestCheckpoint());
 
-               checkpoints.addCheckpoint(expected[1], sharedStateRegistry);
+               checkpoints.addCheckpoint(expected[1]);
                assertEquals(2, checkpoints.getNumberOfRetainedCheckpoints());
                verifyCheckpoint(expected[1], 
checkpoints.getLatestCheckpoint());
        }
@@ -93,8 +93,7 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
         */
        @Test
        public void testAddCheckpointMoreThanMaxRetained() throws Exception {
-               CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(1);   
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+               CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(1);
 
                TestCompletedCheckpoint[] expected = new 
TestCompletedCheckpoint[] {
                                createCheckpoint(0), createCheckpoint(1),
@@ -102,13 +101,13 @@ public abstract class CompletedCheckpointStoreTest 
extends TestLogger {
                };
 
                // Add checkpoints
-               checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
+               checkpoints.addCheckpoint(expected[0]);
                assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
 
                for (int i = 1; i < expected.length; i++) {
                        Collection<TaskState> taskStates = expected[i - 
1].getTaskStates().values();
 
-                       checkpoints.addCheckpoint(expected[i], 
sharedStateRegistry);
+                       checkpoints.addCheckpoint(expected[i]);
 
                        // The ZooKeeper implementation discards asynchronously
                        expected[i - 1].awaitDiscard();
@@ -117,7 +116,7 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
 
                        for (TaskState taskState : taskStates) {
                                for (SubtaskState subtaskState : 
taskState.getStates()) {
-                                       verify(subtaskState, 
times(1)).unregisterSharedStates(sharedStateRegistry);
+                                       verify(subtaskState, 
times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
                                }
                        }
                }
@@ -146,7 +145,6 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
        @Test
        public void testGetAllCheckpoints() throws Exception {
                CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(4);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
 
                TestCompletedCheckpoint[] expected = new 
TestCompletedCheckpoint[] {
                                createCheckpoint(0), createCheckpoint(1),
@@ -154,7 +152,7 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
                };
 
                for (TestCompletedCheckpoint checkpoint : expected) {
-                       checkpoints.addCheckpoint(checkpoint, 
sharedStateRegistry);
+                       checkpoints.addCheckpoint(checkpoint);
                }
 
                List<CompletedCheckpoint> actual = 
checkpoints.getAllCheckpoints();
@@ -172,7 +170,6 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
        @Test
        public void testDiscardAllCheckpoints() throws Exception {
                CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(4);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
 
                TestCompletedCheckpoint[] expected = new 
TestCompletedCheckpoint[] {
                                createCheckpoint(0), createCheckpoint(1),
@@ -180,10 +177,10 @@ public abstract class CompletedCheckpointStoreTest 
extends TestLogger {
                };
 
                for (TestCompletedCheckpoint checkpoint : expected) {
-                       checkpoints.addCheckpoint(checkpoint, 
sharedStateRegistry);
+                       checkpoints.addCheckpoint(checkpoint);
                }
 
-               checkpoints.shutdown(JobStatus.FINISHED, sharedStateRegistry);
+               checkpoints.shutdown(JobStatus.FINISHED);
 
                // Empty state
                assertNull(checkpoints.getLatestCheckpoint());
@@ -235,10 +232,10 @@ public abstract class CompletedCheckpointStoreTest 
extends TestLogger {
                }
        }
 
-       protected void verifyCheckpointRegistered(Collection<TaskState> 
taskStates, SharedStateRegistry sharedStateRegistry) {
+       protected void verifyCheckpointRegistered(Collection<TaskState> 
taskStates, SharedStateRegistry registry) {
                for (TaskState taskState : taskStates) {
                        for (SubtaskState subtaskState : taskState.getStates()) 
{
-                               verify(subtaskState, 
times(1)).registerSharedStates(eq(sharedStateRegistry));
+                               verify(subtaskState, 
times(1)).registerSharedStates(eq(registry));
                        }
                }
        }
@@ -246,7 +243,6 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
        protected void verifyCheckpointDiscarded(Collection<TaskState> 
taskStates) {
                for (TaskState taskState : taskStates) {
                        for (SubtaskState subtaskState : taskState.getStates()) 
{
-                               verify(subtaskState, 
times(1)).discardSharedStatesOnFail();
                                verify(subtaskState, times(1)).discardState();
                        }
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
index e7c1c3b..5fce62e 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
@@ -31,17 +31,14 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.jobmanager.scheduler.Scheduler;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.util.SerializedValue;
 
 import org.junit.Test;
-import org.mockito.Matchers;
 
 import java.net.URL;
 import java.util.Collections;
 
-import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -62,7 +59,7 @@ public class ExecutionGraphCheckpointCoordinatorTest {
                graph.fail(new Exception("Test Exception"));
 
                verify(counter, times(1)).shutdown(JobStatus.FAILED);
-               verify(store, times(1)).shutdown(eq(JobStatus.FAILED), 
any(SharedStateRegistry.class));
+               verify(store, times(1)).shutdown(eq(JobStatus.FAILED));
        }
 
        /**
@@ -79,7 +76,7 @@ public class ExecutionGraphCheckpointCoordinatorTest {
 
                // No shutdown
                verify(counter, times(1)).shutdown(eq(JobStatus.SUSPENDED));
-               verify(store, times(1)).shutdown(eq(JobStatus.SUSPENDED), 
any(SharedStateRegistry.class));
+               verify(store, times(1)).shutdown(eq(JobStatus.SUSPENDED));
        }
 
        private ExecutionGraph createExecutionGraphAndEnableCheckpointing(

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index d77fac1..2dd1803 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -207,7 +207,6 @@ public class PendingCheckpointTest {
                // execute asynchronous discard operation
                executor.runQueuedCommands();
                verify(state, times(1)).discardState();
-               verify(state, times(1)).discardSharedStatesOnFail();
 
                // Abort error
                Mockito.reset(state);
@@ -219,7 +218,6 @@ public class PendingCheckpointTest {
                // execute asynchronous discard operation
                executor.runQueuedCommands();
                verify(state, times(1)).discardState();
-               verify(state, times(1)).discardSharedStatesOnFail();
 
                // Abort expired
                Mockito.reset(state);
@@ -231,7 +229,6 @@ public class PendingCheckpointTest {
                // execute asynchronous discard operation
                executor.runQueuedCommands();
                verify(state, times(1)).discardState();
-               verify(state, times(1)).discardSharedStatesOnFail();
 
                // Abort subsumed
                Mockito.reset(state);
@@ -243,7 +240,6 @@ public class PendingCheckpointTest {
                // execute asynchronous discard operation
                executor.runQueuedCommands();
                verify(state, times(1)).discardState();
-               verify(state, times(1)).discardSharedStatesOnFail();
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
index 7a85897..64aeeba 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
@@ -30,6 +30,7 @@ import java.util.List;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
 import static org.powermock.api.mockito.PowerMockito.doReturn;
 import static org.powermock.api.mockito.PowerMockito.doThrow;
 import static org.powermock.api.mockito.PowerMockito.mock;
@@ -40,7 +41,7 @@ import static org.powermock.api.mockito.PowerMockito.mock;
 public class StandaloneCompletedCheckpointStoreTest extends 
CompletedCheckpointStoreTest {
 
        @Override
-       protected CompletedCheckpointStore createCompletedCheckpoints(
+       protected AbstractCompletedCheckpointStore createCompletedCheckpoints(
                        int maxNumberOfCheckpointsToRetain) throws Exception {
 
                return new 
StandaloneCompletedCheckpointStore(maxNumberOfCheckpointsToRetain);
@@ -51,16 +52,15 @@ public class StandaloneCompletedCheckpointStoreTest extends 
CompletedCheckpointS
         */
        @Test
        public void testShutdownDiscardsCheckpoints() throws Exception {
-               CompletedCheckpointStore store = createCompletedCheckpoints(1);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+               AbstractCompletedCheckpointStore store = 
createCompletedCheckpoints(1);
                TestCompletedCheckpoint checkpoint = createCheckpoint(0);
                Collection<TaskState> taskStates = 
checkpoint.getTaskStates().values();
 
-               store.addCheckpoint(checkpoint, sharedStateRegistry);
+               store.addCheckpoint(checkpoint);
                assertEquals(1, store.getNumberOfRetainedCheckpoints());
-               verifyCheckpointRegistered(taskStates, sharedStateRegistry);
+               verifyCheckpointRegistered(taskStates, 
store.sharedStateRegistry);
 
-               store.shutdown(JobStatus.FINISHED, sharedStateRegistry);
+               store.shutdown(JobStatus.FINISHED);
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
                assertTrue(checkpoint.isDiscarded());
                verifyCheckpointDiscarded(taskStates);
@@ -72,16 +72,15 @@ public class StandaloneCompletedCheckpointStoreTest extends 
CompletedCheckpointS
         */
        @Test
        public void testSuspendDiscardsCheckpoints() throws Exception {
-               CompletedCheckpointStore store = createCompletedCheckpoints(1);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+               AbstractCompletedCheckpointStore store = 
createCompletedCheckpoints(1);
                TestCompletedCheckpoint checkpoint = createCheckpoint(0);
                Collection<TaskState> taskStates = 
checkpoint.getTaskStates().values();
 
-               store.addCheckpoint(checkpoint, sharedStateRegistry);
+               store.addCheckpoint(checkpoint);
                assertEquals(1, store.getNumberOfRetainedCheckpoints());
-               verifyCheckpointRegistered(taskStates, sharedStateRegistry);
+               verifyCheckpointRegistered(taskStates, 
store.sharedStateRegistry);
 
-               store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
+               store.shutdown(JobStatus.SUSPENDED);
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
                assertTrue(checkpoint.isDiscarded());
                verifyCheckpointDiscarded(taskStates);
@@ -96,16 +95,15 @@ public class StandaloneCompletedCheckpointStoreTest extends 
CompletedCheckpointS
                
                final int numCheckpointsToRetain = 1;
                CompletedCheckpointStore store = 
createCompletedCheckpoints(numCheckpointsToRetain);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                
                for (long i = 0; i <= numCheckpointsToRetain; ++i) {
                        CompletedCheckpoint checkpointToAdd = 
mock(CompletedCheckpoint.class);
                        doReturn(i).when(checkpointToAdd).getCheckpointID();
                        
doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates();
-                       doThrow(new 
IOException()).when(checkpointToAdd).discardOnSubsume(sharedStateRegistry);
+                       doThrow(new 
IOException()).when(checkpointToAdd).discardOnSubsume(any(SharedStateRegistry.class));
                        
                        try {
-                               store.addCheckpoint(checkpointToAdd, 
sharedStateRegistry);
+                               store.addCheckpoint(checkpointToAdd);
                                
                                // The checkpoint should be in the store if we 
successfully add it into the store.
                                List<CompletedCheckpoint> addedCheckpoints = 
store.getAllCheckpoints();

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
index 607e773..73fcf78 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
@@ -22,7 +22,6 @@ import org.apache.curator.framework.CuratorFramework;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment;
 import org.junit.AfterClass;
@@ -59,7 +58,7 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends 
CompletedCheckpoint
        }
 
        @Override
-       protected CompletedCheckpointStore createCompletedCheckpoints(
+       protected AbstractCompletedCheckpointStore createCompletedCheckpoints(
                        int maxNumberOfCheckpointsToRetain) throws Exception {
 
                return new 
ZooKeeperCompletedCheckpointStore(maxNumberOfCheckpointsToRetain,
@@ -80,21 +79,20 @@ public class ZooKeeperCompletedCheckpointStoreITCase 
extends CompletedCheckpoint
         */
        @Test
        public void testRecover() throws Exception {
-               CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(3);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+               AbstractCompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(3);
 
                TestCompletedCheckpoint[] expected = new 
TestCompletedCheckpoint[] {
                                createCheckpoint(0), createCheckpoint(1), 
createCheckpoint(2)
                };
 
                // Add multiple checkpoints
-               checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
-               checkpoints.addCheckpoint(expected[1], sharedStateRegistry);
-               checkpoints.addCheckpoint(expected[2], sharedStateRegistry);
+               checkpoints.addCheckpoint(expected[0]);
+               checkpoints.addCheckpoint(expected[1]);
+               checkpoints.addCheckpoint(expected[2]);
 
-               
verifyCheckpointRegistered(expected[0].getTaskStates().values(), 
sharedStateRegistry);
-               
verifyCheckpointRegistered(expected[1].getTaskStates().values(), 
sharedStateRegistry);
-               
verifyCheckpointRegistered(expected[2].getTaskStates().values(), 
sharedStateRegistry);
+               
verifyCheckpointRegistered(expected[0].getTaskStates().values(), 
checkpoints.sharedStateRegistry);
+               
verifyCheckpointRegistered(expected[1].getTaskStates().values(), 
checkpoints.sharedStateRegistry);
+               
verifyCheckpointRegistered(expected[2].getTaskStates().values(), 
checkpoints.sharedStateRegistry);
 
                // All three should be in ZK
                assertEquals(3, 
ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size());
@@ -104,9 +102,8 @@ public class ZooKeeperCompletedCheckpointStoreITCase 
extends CompletedCheckpoint
                resetCheckpoint(expected[1].getTaskStates().values());
                resetCheckpoint(expected[2].getTaskStates().values());
 
-               // Recover
-               SharedStateRegistry newSharedStateRegistry = new 
SharedStateRegistry();
-               checkpoints.recover(newSharedStateRegistry);
+               // Recover TODO!!! clear registry!
+               checkpoints.recover();
 
                assertEquals(3, 
ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size());
                assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints());
@@ -117,14 +114,14 @@ public class ZooKeeperCompletedCheckpointStoreITCase 
extends CompletedCheckpoint
                expectedCheckpoints.add(expected[2]);
                expectedCheckpoints.add(createCheckpoint(3));
 
-               checkpoints.addCheckpoint(expectedCheckpoints.get(2), 
newSharedStateRegistry);
+               checkpoints.addCheckpoint(expectedCheckpoints.get(2));
 
                List<CompletedCheckpoint> actualCheckpoints = 
checkpoints.getAllCheckpoints();
 
                assertEquals(expectedCheckpoints, actualCheckpoints);
 
                for (CompletedCheckpoint actualCheckpoint : actualCheckpoints) {
-                       
verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), 
newSharedStateRegistry);
+                       
verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), 
checkpoints.sharedStateRegistry);
                }
        }
 
@@ -136,18 +133,17 @@ public class ZooKeeperCompletedCheckpointStoreITCase 
extends CompletedCheckpoint
                CuratorFramework client = ZooKeeper.getClient();
 
                CompletedCheckpointStore store = createCompletedCheckpoints(1);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 
-               store.addCheckpoint(checkpoint, sharedStateRegistry);
+               store.addCheckpoint(checkpoint);
                assertEquals(1, store.getNumberOfRetainedCheckpoints());
                assertNotNull(client.checkExists().forPath(CheckpointsPath + 
"/" + checkpoint.getCheckpointID()));
 
-               store.shutdown(JobStatus.FINISHED, sharedStateRegistry);
+               store.shutdown(JobStatus.FINISHED);
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
                assertNull(client.checkExists().forPath(CheckpointsPath + "/" + 
checkpoint.getCheckpointID()));
 
-               store.recover(sharedStateRegistry);
+               store.recover();
 
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
        }
@@ -161,20 +157,19 @@ public class ZooKeeperCompletedCheckpointStoreITCase 
extends CompletedCheckpoint
                CuratorFramework client = ZooKeeper.getClient();
 
                CompletedCheckpointStore store = createCompletedCheckpoints(1);
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 
-               store.addCheckpoint(checkpoint, sharedStateRegistry);
+               store.addCheckpoint(checkpoint);
                assertEquals(1, store.getNumberOfRetainedCheckpoints());
                assertNotNull(client.checkExists().forPath(CheckpointsPath + 
"/" + checkpoint.getCheckpointID()));
 
-               store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
+               store.shutdown(JobStatus.SUSPENDED);
 
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
                assertNotNull(client.checkExists().forPath(CheckpointsPath + 
"/" + checkpoint.getCheckpointID()));
 
                // Recover again
-               store.recover(sharedStateRegistry);
+               store.recover();
 
                CompletedCheckpoint recovered = store.getLatestCheckpoint();
                assertEquals(checkpoint, recovered);

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
index 1f5731d..66ef232 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
@@ -27,7 +27,6 @@ import org.apache.curator.utils.EnsurePath;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore;
 import org.apache.flink.util.TestLogger;
@@ -160,9 +159,7 @@ public class ZooKeeperCompletedCheckpointStoreTest extends 
TestLogger {
                        stateSotrage,
                        Executors.directExecutor());
 
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
-
-               zooKeeperCompletedCheckpointStore.recover(sharedStateRegistry);
+               zooKeeperCompletedCheckpointStore.recover();
 
                CompletedCheckpoint latestCompletedCheckpoint = 
zooKeeperCompletedCheckpointStore.getLatestCheckpoint();
 
@@ -227,16 +224,13 @@ public class ZooKeeperCompletedCheckpointStoreTest 
extends TestLogger {
                        stateSotrage,
                        Executors.directExecutor());
 
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
-               
-               
                for (long i = 0; i <= numCheckpointsToRetain; ++i) {
                        CompletedCheckpoint checkpointToAdd = 
mock(CompletedCheckpoint.class);
                        doReturn(i).when(checkpointToAdd).getCheckpointID();
                        
doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates();
                        
                        try {
-                               
zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd, 
sharedStateRegistry);
+                               
zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd);
                                
                                // The checkpoint should be in the store if we 
successfully add it into the store.
                                List<CompletedCheckpoint> addedCheckpoints = 
zooKeeperCompletedCheckpointStore.getAllCheckpoints();

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
index cb14ff0..821bb69 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
@@ -34,50 +34,20 @@ public class SharedStateRegistryTest {
 
                // register one state
                TestSharedState firstState = new TestSharedState("first");
-               sharedStateRegistry.register(firstState, true);
-               assertEquals(1, 
sharedStateRegistry.getReferenceCount(firstState));
+               assertEquals(1, sharedStateRegistry.register(firstState));
 
                // register another state
                TestSharedState secondState = new TestSharedState("second");
-               sharedStateRegistry.register(secondState, true);
-               assertEquals(1, 
sharedStateRegistry.getReferenceCount(secondState));
+               assertEquals(1, sharedStateRegistry.register(secondState));
 
                // register the first state again
-               sharedStateRegistry.register(firstState, false);
-               assertEquals(2, 
sharedStateRegistry.getReferenceCount(firstState));
+               assertEquals(2, sharedStateRegistry.register(firstState));
 
                // unregister the second state
-               sharedStateRegistry.unregister(secondState);
-               assertEquals(0, 
sharedStateRegistry.getReferenceCount(secondState));
+               assertEquals(0, sharedStateRegistry.unregister(secondState));
 
                // unregister the first state
-               sharedStateRegistry.unregister(firstState);
-               assertEquals(1, 
sharedStateRegistry.getReferenceCount(firstState));
-       }
-
-       /**
-        * Validate that registering a handle referencing uncreated state will 
throw exception
-        */
-       @Test(expected = IllegalStateException.class)
-       public void testRegisterWithUncreatedReference() {
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
-
-               // register one state
-               TestSharedState state = new TestSharedState("state");
-               sharedStateRegistry.register(state, false);
-       }
-
-       /**
-        * Validate that registering duplicate creation of the same state will 
throw exception
-        */
-       @Test(expected = IllegalStateException.class)
-       public void testRegisterWithDuplicateState() {
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
-
-               // register one state
-               TestSharedState state = new TestSharedState("state");
-               sharedStateRegistry.register(state, true);
-               sharedStateRegistry.register(state, true);
+               assertEquals(1, sharedStateRegistry.unregister(firstState));
        }
 
        /**
@@ -100,7 +70,7 @@ public class SharedStateRegistryTest {
                }
 
                @Override
-               public String getKey() {
+               public String getRegistrationKey() {
                        return key;
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
index 75b0f6f..a932c18 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
@@ -18,12 +18,9 @@
 
 package org.apache.flink.runtime.testutils;
 
+import org.apache.flink.runtime.checkpoint.AbstractCompletedCheckpointStore;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
-import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
 import org.apache.flink.runtime.jobgraph.JobStatus;
-import org.apache.flink.runtime.state.SharedStateRegistry;
-import org.apache.flink.runtime.state.StateObject;
-import org.apache.flink.runtime.state.StateUtil;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -35,7 +32,7 @@ import java.util.List;
  * A checkpoint store, which supports shutdown and suspend. You can use this to test HA
  * as long as the factory always returns the same store instance.
  */
-public class RecoverableCompletedCheckpointStore implements CompletedCheckpointStore {
+public class RecoverableCompletedCheckpointStore extends AbstractCompletedCheckpointStore {
 
        private static final Logger LOG = LoggerFactory.getLogger(RecoverableCompletedCheckpointStore.class);
 
@@ -44,7 +41,7 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
        private final ArrayDeque<CompletedCheckpoint> suspended = new ArrayDeque<>(2);
 
        @Override
-       public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+       public void recover() throws Exception {
                checkpoints.addAll(suspended);
                suspended.clear();
 
@@ -54,7 +51,7 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
        }
 
        @Override
-       public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception {
+       public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
                checkpoints.addLast(checkpoint);
 
                checkpoint.registerSharedStates(sharedStateRegistry);
@@ -71,7 +68,7 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
        }
 
        @Override
-       public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+       public void shutdown(JobStatus jobStatus) throws Exception {
                if (jobStatus.isGloballyTerminalState()) {
                        checkpoints.clear();
                        suspended.clear();

Reply via email to