http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index f77c755..aa1726b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -21,13 +21,14 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.messages.CheckpointMessagesTest;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.util.TestLogger;
 import org.junit.Test;
+import org.mockito.Mockito;
 
 import java.io.IOException;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -36,6 +37,9 @@ import java.util.concurrent.CountDownLatch;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
 /**
  * Test for basic {@link CompletedCheckpointStore} contract.
@@ -64,7 +68,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
        @Test
        public void testAddAndGetLatestCheckpoint() throws Exception {
                CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(4);
-
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+               
                // Empty state
                assertEquals(0, checkpoints.getNumberOfRetainedCheckpoints());
                assertEquals(0, checkpoints.getAllCheckpoints().size());
@@ -73,11 +78,11 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
                                createCheckpoint(0), createCheckpoint(1) };
 
                // Add and get latest
-               checkpoints.addCheckpoint(expected[0]);
+               checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
                assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
                verifyCheckpoint(expected[0], 
checkpoints.getLatestCheckpoint());
 
-               checkpoints.addCheckpoint(expected[1]);
+               checkpoints.addCheckpoint(expected[1], sharedStateRegistry);
                assertEquals(2, checkpoints.getNumberOfRetainedCheckpoints());
                verifyCheckpoint(expected[1], 
checkpoints.getLatestCheckpoint());
        }
@@ -88,7 +93,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
         */
        @Test
        public void testAddCheckpointMoreThanMaxRetained() throws Exception {
-               CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(1);
+               CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(1);   
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
 
                TestCompletedCheckpoint[] expected = new 
TestCompletedCheckpoint[] {
                                createCheckpoint(0), createCheckpoint(1),
@@ -96,16 +102,24 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
                };
 
                // Add checkpoints
-               checkpoints.addCheckpoint(expected[0]);
+               checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
                assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
 
                for (int i = 1; i < expected.length; i++) {
-                       checkpoints.addCheckpoint(expected[i]);
+                       Collection<TaskState> taskStates = expected[i - 
1].getTaskStates().values();
+
+                       checkpoints.addCheckpoint(expected[i], 
sharedStateRegistry);
 
                        // The ZooKeeper implementation discards asynchronously
                        expected[i - 1].awaitDiscard();
                        assertTrue(expected[i - 1].isDiscarded());
                        assertEquals(1, 
checkpoints.getNumberOfRetainedCheckpoints());
+
+                       for (TaskState taskState : taskStates) {
+                               for (SubtaskState subtaskState : 
taskState.getStates()) {
+                                       verify(subtaskState, 
times(1)).unregisterSharedStates(sharedStateRegistry);
+                               }
+                       }
                }
        }
 
@@ -132,6 +146,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
        @Test
        public void testGetAllCheckpoints() throws Exception {
                CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(4);
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
 
                TestCompletedCheckpoint[] expected = new 
TestCompletedCheckpoint[] {
                                createCheckpoint(0), createCheckpoint(1),
@@ -139,7 +154,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
                };
 
                for (TestCompletedCheckpoint checkpoint : expected) {
-                       checkpoints.addCheckpoint(checkpoint);
+                       checkpoints.addCheckpoint(checkpoint, 
sharedStateRegistry);
                }
 
                List<CompletedCheckpoint> actual = 
checkpoints.getAllCheckpoints();
@@ -157,6 +172,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
        @Test
        public void testDiscardAllCheckpoints() throws Exception {
                CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(4);
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
 
                TestCompletedCheckpoint[] expected = new 
TestCompletedCheckpoint[] {
                                createCheckpoint(0), createCheckpoint(1),
@@ -164,10 +180,10 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
                };
 
                for (TestCompletedCheckpoint checkpoint : expected) {
-                       checkpoints.addCheckpoint(checkpoint);
+                       checkpoints.addCheckpoint(checkpoint, 
sharedStateRegistry);
                }
 
-               checkpoints.shutdown(JobStatus.FINISHED);
+               checkpoints.shutdown(JobStatus.FINISHED, sharedStateRegistry);
 
                // Empty state
                assertNull(checkpoints.getLatestCheckpoint());
@@ -203,15 +219,39 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
                taskGroupStates.put(jvid, taskState);
 
                for (int i = 0; i < numberOfStates; i++) {
-                       ChainedStateHandle<StreamStateHandle> stateHandle = 
CheckpointCoordinatorTest.generateChainedStateHandle(
-                                       new CheckpointMessagesTest.MyHandle());
+                       SubtaskState subtaskState = 
CheckpointCoordinatorTest.mockSubtaskState(jvid, i, new KeyGroupRange(i, i));
 
-                       taskState.putState(i, new SubtaskState(stateHandle, 
null, null, null, null));
+                       taskState.putState(i, subtaskState);
                }
 
                return new TestCompletedCheckpoint(new JobID(), id, 0, 
taskGroupStates, props);
        }
 
+       protected void resetCheckpoint(Collection<TaskState> taskStates) {
+               for (TaskState taskState : taskStates) {
+                       for (SubtaskState subtaskState : taskState.getStates()) 
{
+                               Mockito.reset(subtaskState);
+                       }
+               }
+       }
+
+       protected void verifyCheckpointRegistered(Collection<TaskState> 
taskStates, SharedStateRegistry sharedStateRegistry) {
+               for (TaskState taskState : taskStates) {
+                       for (SubtaskState subtaskState : taskState.getStates()) 
{
+                               verify(subtaskState, 
times(1)).registerSharedStates(eq(sharedStateRegistry));
+                       }
+               }
+       }
+
+       protected void verifyCheckpointDiscarded(Collection<TaskState> 
taskStates) {
+               for (TaskState taskState : taskStates) {
+                       for (SubtaskState subtaskState : taskState.getStates()) 
{
+                               verify(subtaskState, 
times(1)).discardSharedStatesOnFail();
+                               verify(subtaskState, times(1)).discardState();
+                       }
+               }
+       }
+
        private void verifyCheckpoint(CompletedCheckpoint expected, 
CompletedCheckpoint actual) {
                assertEquals(expected, actual);
        }
@@ -241,8 +281,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
                }
 
                @Override
-               public boolean subsume() throws Exception {
-                       if (super.subsume()) {
+               public boolean discardOnSubsume(SharedStateRegistry 
sharedStateRegistry) throws Exception {
+                       if (super.discardOnSubsume(sharedStateRegistry)) {
                                discard();
                                return true;
                        } else {
@@ -251,8 +291,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
                }
 
                @Override
-               public boolean discard(JobStatus jobStatus) throws Exception {
-                       if (super.discard(jobStatus)) {
+               public boolean discardOnShutdown(JobStatus jobStatus, 
SharedStateRegistry sharedStateRegistry) throws Exception {
+                       if (super.discardOnShutdown(jobStatus, 
sharedStateRegistry)) {
                                discard();
                                return true;
                        } else {

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java
index b34e9a6..0b759d4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java
@@ -23,6 +23,8 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.SharedStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.junit.Rule;
 import org.junit.Test;
@@ -30,10 +32,12 @@ import org.junit.rules.TemporaryFolder;
 import org.mockito.Mockito;
 
 import java.io.File;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
 import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -61,7 +65,7 @@ public class CompletedCheckpointTest {
                                new FileStateHandle(new Path(file.toURI()), 
file.length()),
                                file.getAbsolutePath());
 
-               checkpoint.discard(JobStatus.FAILED);
+               checkpoint.discardOnShutdown(JobStatus.FAILED, new 
SharedStateRegistry());
 
                assertEquals(false, file.exists());
        }
@@ -80,10 +84,15 @@ public class CompletedCheckpointTest {
                CompletedCheckpoint checkpoint = new CompletedCheckpoint(
                                new JobID(), 0, 0, 1, taskStates, props);
 
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+               checkpoint.registerSharedStates(sharedStateRegistry);
+               verify(state, 
times(1)).registerSharedStates(sharedStateRegistry);
+
                // Subsume
-               checkpoint.subsume();
+               checkpoint.discardOnSubsume(sharedStateRegistry);
 
                verify(state, times(1)).discardState();
+               verify(state, 
times(1)).unregisterSharedStates(sharedStateRegistry);
        }
 
        /**
@@ -112,17 +121,22 @@ public class CompletedCheckpointTest {
                                        new FileStateHandle(new 
Path(file.toURI()), file.length()),
                                        externalPath);
 
-                       checkpoint.discard(status);
+                       SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+                       checkpoint.registerSharedStates(sharedStateRegistry);
+
+                       checkpoint.discardOnShutdown(status, 
sharedStateRegistry);
                        verify(state, times(0)).discardState();
                        assertEquals(true, file.exists());
+                       verify(state, 
times(0)).unregisterSharedStates(sharedStateRegistry);
 
                        // Discard
                        props = new CheckpointProperties(false, false, true, 
true, true, true, true);
                        checkpoint = new CompletedCheckpoint(
                                        new JobID(), 0, 0, 1, new 
HashMap<>(taskStates), props);
 
-                       checkpoint.discard(status);
+                       checkpoint.discardOnShutdown(status, 
sharedStateRegistry);
                        verify(state, times(1)).discardState();
+                       verify(state, 
times(1)).unregisterSharedStates(sharedStateRegistry);
                }
        }
 
@@ -146,7 +160,7 @@ public class CompletedCheckpointTest {
                CompletedCheckpointStats.DiscardCallback callback = 
mock(CompletedCheckpointStats.DiscardCallback.class);
                completed.setDiscardCallback(callback);
 
-               completed.discard(JobStatus.FINISHED);
+               completed.discardOnShutdown(JobStatus.FINISHED, new 
SharedStateRegistry());
                verify(callback, times(1)).notifyDiscardedCheckpoint();
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
index 0ab031e..e7c1c3b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.jobmanager.scheduler.Scheduler;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.util.SerializedValue;
 
@@ -40,6 +41,8 @@ import org.mockito.Matchers;
 import java.net.URL;
 import java.util.Collections;
 
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -59,7 +62,7 @@ public class ExecutionGraphCheckpointCoordinatorTest {
                graph.fail(new Exception("Test Exception"));
 
                verify(counter, times(1)).shutdown(JobStatus.FAILED);
-               verify(store, times(1)).shutdown(JobStatus.FAILED);
+               verify(store, times(1)).shutdown(eq(JobStatus.FAILED), 
any(SharedStateRegistry.class));
        }
 
        /**
@@ -75,8 +78,8 @@ public class ExecutionGraphCheckpointCoordinatorTest {
                graph.suspend(new Exception("Test Exception"));
 
                // No shutdown
-               verify(counter, 
times(1)).shutdown(Matchers.eq(JobStatus.SUSPENDED));
-               verify(store, 
times(1)).shutdown(Matchers.eq(JobStatus.SUSPENDED));
+               verify(counter, times(1)).shutdown(eq(JobStatus.SUSPENDED));
+               verify(store, times(1)).shutdown(eq(JobStatus.SUSPENDED), 
any(SharedStateRegistry.class));
        }
 
        private ExecutionGraph createExecutionGraphAndEnableCheckpointing(

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index a15684c..d77fac1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -24,15 +24,19 @@ import org.apache.flink.runtime.concurrent.Future;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.SharedStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
+import org.mockito.Mock;
 import org.mockito.Mockito;
 
 import java.io.File;
 import java.lang.reflect.Field;
 import java.util.ArrayDeque;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Queue;
@@ -45,7 +49,10 @@ import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.powermock.api.mockito.PowerMockito.when;
@@ -184,9 +191,12 @@ public class PendingCheckpointTest {
        @SuppressWarnings("unchecked")
        public void testAbortDiscardsState() throws Exception {
                CheckpointProperties props = new CheckpointProperties(false, 
true, false, false, false, false, false);
-               TaskState state = mock(TaskState.class);
                QueueExecutor executor = new QueueExecutor();
 
+               TaskState state = mock(TaskState.class);
+               
doNothing().when(state).registerSharedStates(any(SharedStateRegistry.class));
+               
doNothing().when(state).unregisterSharedStates(any(SharedStateRegistry.class));
+
                String targetDir = tmpFolder.newFolder().getAbsolutePath();
 
                // Abort declined
@@ -197,6 +207,7 @@ public class PendingCheckpointTest {
                // execute asynchronous discard operation
                executor.runQueuedCommands();
                verify(state, times(1)).discardState();
+               verify(state, times(1)).discardSharedStatesOnFail();
 
                // Abort error
                Mockito.reset(state);
@@ -208,6 +219,7 @@ public class PendingCheckpointTest {
                // execute asynchronous discard operation
                executor.runQueuedCommands();
                verify(state, times(1)).discardState();
+               verify(state, times(1)).discardSharedStatesOnFail();
 
                // Abort expired
                Mockito.reset(state);
@@ -219,6 +231,7 @@ public class PendingCheckpointTest {
                // execute asynchronous discard operation
                executor.runQueuedCommands();
                verify(state, times(1)).discardState();
+               verify(state, times(1)).discardSharedStatesOnFail();
 
                // Abort subsumed
                Mockito.reset(state);
@@ -230,6 +243,7 @@ public class PendingCheckpointTest {
                // execute asynchronous discard operation
                executor.runQueuedCommands();
                verify(state, times(1)).discardState();
+               verify(state, times(1)).discardSharedStatesOnFail();
        }
 
        /**
@@ -340,7 +354,11 @@ public class PendingCheckpointTest {
                return createPendingCheckpoint(props, targetDirectory, 
Executors.directExecutor());
        }
 
-       private static PendingCheckpoint 
createPendingCheckpoint(CheckpointProperties props, String targetDirectory, 
Executor executor) {
+       private static PendingCheckpoint createPendingCheckpoint(
+                       CheckpointProperties props,
+                       String targetDirectory,
+                       Executor executor) {
+
                Map<ExecutionAttemptID, ExecutionVertex> ackTasks = new 
HashMap<>(ACK_TASKS);
                return new PendingCheckpoint(
                        new JobID(),

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
index cc7b2d0..7a85897 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
@@ -19,9 +19,12 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.junit.Test;
 
 import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
@@ -49,15 +52,18 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
        @Test
        public void testShutdownDiscardsCheckpoints() throws Exception {
                CompletedCheckpointStore store = createCompletedCheckpoints(1);
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                TestCompletedCheckpoint checkpoint = createCheckpoint(0);
+               Collection<TaskState> taskStates = 
checkpoint.getTaskStates().values();
 
-               store.addCheckpoint(checkpoint);
+               store.addCheckpoint(checkpoint, sharedStateRegistry);
                assertEquals(1, store.getNumberOfRetainedCheckpoints());
+               verifyCheckpointRegistered(taskStates, sharedStateRegistry);
 
-               store.shutdown(JobStatus.FINISHED);
-
+               store.shutdown(JobStatus.FINISHED, sharedStateRegistry);
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
                assertTrue(checkpoint.isDiscarded());
+               verifyCheckpointDiscarded(taskStates);
        }
 
        /**
@@ -67,15 +73,18 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
        @Test
        public void testSuspendDiscardsCheckpoints() throws Exception {
                CompletedCheckpointStore store = createCompletedCheckpoints(1);
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                TestCompletedCheckpoint checkpoint = createCheckpoint(0);
+               Collection<TaskState> taskStates = 
checkpoint.getTaskStates().values();
 
-               store.addCheckpoint(checkpoint);
+               store.addCheckpoint(checkpoint, sharedStateRegistry);
                assertEquals(1, store.getNumberOfRetainedCheckpoints());
+               verifyCheckpointRegistered(taskStates, sharedStateRegistry);
 
-               store.shutdown(JobStatus.SUSPENDED);
-
+               store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
                assertTrue(checkpoint.isDiscarded());
+               verifyCheckpointDiscarded(taskStates);
        }
        
        /**
@@ -87,14 +96,16 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
                
                final int numCheckpointsToRetain = 1;
                CompletedCheckpointStore store = 
createCompletedCheckpoints(numCheckpointsToRetain);
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                
                for (long i = 0; i <= numCheckpointsToRetain; ++i) {
                        CompletedCheckpoint checkpointToAdd = 
mock(CompletedCheckpoint.class);
                        doReturn(i).when(checkpointToAdd).getCheckpointID();
-                       doThrow(new 
IOException()).when(checkpointToAdd).subsume();
+                       
doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates();
+                       doThrow(new 
IOException()).when(checkpointToAdd).discardOnSubsume(sharedStateRegistry);
                        
                        try {
-                               store.addCheckpoint(checkpointToAdd);
+                               store.addCheckpoint(checkpointToAdd, 
sharedStateRegistry);
                                
                                // The checkpoint should be in the store if we 
successfully add it into the store.
                                List<CompletedCheckpoint> addedCheckpoints = 
store.getAllCheckpoints();

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
index 625999a..607e773 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
@@ -22,6 +22,7 @@ import org.apache.curator.framework.CuratorFramework;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment;
 import org.junit.AfterClass;
@@ -80,22 +81,32 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
        @Test
        public void testRecover() throws Exception {
                CompletedCheckpointStore checkpoints = 
createCompletedCheckpoints(3);
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
 
                TestCompletedCheckpoint[] expected = new 
TestCompletedCheckpoint[] {
                                createCheckpoint(0), createCheckpoint(1), 
createCheckpoint(2)
                };
 
                // Add multiple checkpoints
-               checkpoints.addCheckpoint(expected[0]);
-               checkpoints.addCheckpoint(expected[1]);
-               checkpoints.addCheckpoint(expected[2]);
+               checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
+               checkpoints.addCheckpoint(expected[1], sharedStateRegistry);
+               checkpoints.addCheckpoint(expected[2], sharedStateRegistry);
+
+               
verifyCheckpointRegistered(expected[0].getTaskStates().values(), 
sharedStateRegistry);
+               
verifyCheckpointRegistered(expected[1].getTaskStates().values(), 
sharedStateRegistry);
+               
verifyCheckpointRegistered(expected[2].getTaskStates().values(), 
sharedStateRegistry);
 
                // All three should be in ZK
                assertEquals(3, 
ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size());
                assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints());
 
+               resetCheckpoint(expected[0].getTaskStates().values());
+               resetCheckpoint(expected[1].getTaskStates().values());
+               resetCheckpoint(expected[2].getTaskStates().values());
+
                // Recover
-               checkpoints.recover();
+               SharedStateRegistry newSharedStateRegistry = new 
SharedStateRegistry();
+               checkpoints.recover(newSharedStateRegistry);
 
                assertEquals(3, 
ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size());
                assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints());
@@ -106,11 +117,15 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
                expectedCheckpoints.add(expected[2]);
                expectedCheckpoints.add(createCheckpoint(3));
 
-               checkpoints.addCheckpoint(expectedCheckpoints.get(2));
+               checkpoints.addCheckpoint(expectedCheckpoints.get(2), 
newSharedStateRegistry);
 
                List<CompletedCheckpoint> actualCheckpoints = 
checkpoints.getAllCheckpoints();
 
                assertEquals(expectedCheckpoints, actualCheckpoints);
+
+               for (CompletedCheckpoint actualCheckpoint : actualCheckpoints) {
+                       
verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), 
newSharedStateRegistry);
+               }
        }
 
        /**
@@ -121,18 +136,18 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
                CuratorFramework client = ZooKeeper.getClient();
 
                CompletedCheckpointStore store = createCompletedCheckpoints(1);
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 
-               store.addCheckpoint(checkpoint);
+               store.addCheckpoint(checkpoint, sharedStateRegistry);
                assertEquals(1, store.getNumberOfRetainedCheckpoints());
                assertNotNull(client.checkExists().forPath(CheckpointsPath + 
"/" + checkpoint.getCheckpointID()));
 
-               store.shutdown(JobStatus.FINISHED);
-
+               store.shutdown(JobStatus.FINISHED, sharedStateRegistry);
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
                assertNull(client.checkExists().forPath(CheckpointsPath + "/" + 
checkpoint.getCheckpointID()));
 
-               store.recover();
+               store.recover(sharedStateRegistry);
 
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
        }
@@ -146,19 +161,20 @@ public class ZooKeeperCompletedCheckpointStoreITCase 
extends CompletedCheckpoint
                CuratorFramework client = ZooKeeper.getClient();
 
                CompletedCheckpointStore store = createCompletedCheckpoints(1);
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 
-               store.addCheckpoint(checkpoint);
+               store.addCheckpoint(checkpoint, sharedStateRegistry);
                assertEquals(1, store.getNumberOfRetainedCheckpoints());
                assertNotNull(client.checkExists().forPath(CheckpointsPath + 
"/" + checkpoint.getCheckpointID()));
 
-               store.shutdown(JobStatus.SUSPENDED);
+               store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
 
                assertEquals(0, store.getNumberOfRetainedCheckpoints());
                assertNotNull(client.checkExists().forPath(CheckpointsPath + 
"/" + checkpoint.getCheckpointID()));
 
                // Recover again
-               store.recover();
+               store.recover(sharedStateRegistry);
 
                CompletedCheckpoint recovered = store.getLatestCheckpoint();
                assertEquals(checkpoint, recovered);

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
index aa2ec85..1f5731d 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
@@ -27,6 +27,7 @@ import org.apache.curator.utils.EnsurePath;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore;
 import org.apache.flink.util.TestLogger;
@@ -40,6 +41,7 @@ import org.powermock.modules.junit4.PowerMockRunner;
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.concurrent.Executor;
@@ -158,7 +160,9 @@ public class ZooKeeperCompletedCheckpointStoreTest extends 
TestLogger {
                        stateSotrage,
                        Executors.directExecutor());
 
-               zooKeeperCompletedCheckpointStore.recover();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+
+               zooKeeperCompletedCheckpointStore.recover(sharedStateRegistry);
 
                CompletedCheckpoint latestCompletedCheckpoint = 
zooKeeperCompletedCheckpointStore.getLatestCheckpoint();
 
@@ -222,14 +226,17 @@ public class ZooKeeperCompletedCheckpointStoreTest 
extends TestLogger {
                        checkpointsPath,
                        stateSotrage,
                        Executors.directExecutor());
+
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
                
                
                for (long i = 0; i <= numCheckpointsToRetain; ++i) {
                        CompletedCheckpoint checkpointToAdd = 
mock(CompletedCheckpoint.class);
                        doReturn(i).when(checkpointToAdd).getCheckpointID();
+                       
doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates();
                        
                        try {
-                               
zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd);
+                               
zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd, 
sharedStateRegistry);
                                
                                // The checkpoint should be in the store if we 
successfully add it into the store.
                                List<CompletedCheckpoint> addedCheckpoints = 
zooKeeperCompletedCheckpointStore.getAllCheckpoints();

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index 6eacaac..77eb566 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -71,6 +71,7 @@ import 
org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.metrics.MetricRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
@@ -81,6 +82,7 @@ import org.apache.flink.runtime.testingUtils.TestingMessages;
 import org.apache.flink.runtime.testingUtils.TestingTaskManager;
 import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
+import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
 
@@ -164,7 +166,7 @@ public class JobManagerHARecoveryTest {
                        Scheduler scheduler = new 
Scheduler(TestingUtils.defaultExecutionContext());
 
                        MySubmittedJobGraphStore mySubmittedJobGraphStore = new 
MySubmittedJobGraphStore();
-                       MyCheckpointStore checkpointStore = new 
MyCheckpointStore();
+                       CompletedCheckpointStore checkpointStore = new 
RecoverableCompletedCheckpointStore();
                        CheckpointIDCounter checkpointCounter = new 
StandaloneCheckpointIDCounter();
                        CheckpointRecoveryFactory checkpointStateFactory = new 
MyCheckpointRecoveryFactory(checkpointStore, checkpointCounter);
                        TestingLeaderElectionService myLeaderElectionService = 
new TestingLeaderElectionService();
@@ -438,67 +440,6 @@ public class JobManagerHARecoveryTest {
                }
        }
 
-       /**
-        * A checkpoint store, which supports shutdown and suspend. You can use 
this to test HA
-        * as long as the factory always returns the same store instance.
-        */
-       static class MyCheckpointStore implements CompletedCheckpointStore {
-
-               private final ArrayDeque<CompletedCheckpoint> checkpoints = new 
ArrayDeque<>(2);
-
-               private final ArrayDeque<CompletedCheckpoint> suspended = new 
ArrayDeque<>(2);
-
-               @Override
-               public void recover() throws Exception {
-                       checkpoints.addAll(suspended);
-                       suspended.clear();
-               }
-
-               @Override
-               public void addCheckpoint(CompletedCheckpoint checkpoint) 
throws Exception {
-                       checkpoints.addLast(checkpoint);
-                       if (checkpoints.size() > 1) {
-                               checkpoints.removeFirst().subsume();
-                       }
-               }
-
-               @Override
-               public CompletedCheckpoint getLatestCheckpoint() throws 
Exception {
-                       return checkpoints.isEmpty() ? null : 
checkpoints.getLast();
-               }
-
-               @Override
-               public void shutdown(JobStatus jobStatus) throws Exception {
-                       if (jobStatus.isGloballyTerminalState()) {
-                               checkpoints.clear();
-                               suspended.clear();
-                       } else {
-                               suspended.addAll(checkpoints);
-                               checkpoints.clear();
-                       }
-               }
-
-               @Override
-               public List<CompletedCheckpoint> getAllCheckpoints() throws 
Exception {
-                       return new ArrayList<>(checkpoints);
-               }
-
-               @Override
-               public int getNumberOfRetainedCheckpoints() {
-                       return checkpoints.size();
-               }
-
-               @Override
-               public int getMaxNumberOfRetainedCheckpoints() {
-                       return 1;
-               }
-
-               @Override
-               public boolean requiresExternalizedCheckpoints() {
-                       return false;
-               }
-       }
-
        static class MyCheckpointRecoveryFactory implements 
CheckpointRecoveryFactory {
 
                private final CompletedCheckpointStore store;

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
new file mode 100644
index 0000000..cb14ff0
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.runtime.state;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link SharedStateRegistry}.
+ */
+public class SharedStateRegistryTest {
+
+       /**
+        * Validates that states can be registered and unregistered at the registry and that their
+        * reference counts are tracked independently per state.
+        */
+       @Test
+       public void testRegistryNormal() {
+               SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+               // register the creation of one state
+               TestSharedState firstState = new TestSharedState("first");
+               sharedStateRegistry.register(firstState, true);
+               assertEquals(1, sharedStateRegistry.getReferenceCount(firstState));
+
+               // register the creation of another state
+               TestSharedState secondState = new TestSharedState("second");
+               sharedStateRegistry.register(secondState, true);
+               assertEquals(1, sharedStateRegistry.getReferenceCount(secondState));
+
+               // register a reference to the already created first state
+               sharedStateRegistry.register(firstState, false);
+               assertEquals(2, sharedStateRegistry.getReferenceCount(firstState));
+
+               // unregister the second state; the first state must not be affected
+               sharedStateRegistry.unregister(secondState);
+               assertEquals(0, sharedStateRegistry.getReferenceCount(secondState));
+               assertEquals(2, sharedStateRegistry.getReferenceCount(firstState));
+
+               // unregister one reference to the first state
+               sharedStateRegistry.unregister(firstState);
+               assertEquals(1, sharedStateRegistry.getReferenceCount(firstState));
+       }
+
+       /**
+        * Validates that registering a reference to a state that has not been created throws an
+        * exception.
+        */
+       @Test(expected = IllegalStateException.class)
+       public void testRegisterWithUncreatedReference() {
+               SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+               // register a reference to a state that was never created
+               TestSharedState state = new TestSharedState("state");
+               sharedStateRegistry.register(state, false);
+       }
+
+       /**
+        * Validates that registering the creation of the same state twice throws an exception.
+        */
+       @Test(expected = IllegalStateException.class)
+       public void testRegisterWithDuplicateState() {
+               SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+               // register the creation of the same state twice
+               TestSharedState state = new TestSharedState("state");
+               sharedStateRegistry.register(state, true);
+               sharedStateRegistry.register(state, true);
+       }
+
+       /**
+        * Validates that unregistering a state that was never registered throws an exception.
+        */
+       @Test(expected = IllegalStateException.class)
+       public void testUnregisterWithUnexistedKey() {
+               SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+               sharedStateRegistry.unregister(new TestSharedState("unexisted"));
+       }
+
+       /**
+        * A minimal {@link SharedStateHandle} for tests, identified by its key; discarding it is a no-op.
+        */
+       private static class TestSharedState implements SharedStateHandle {
+               private static final long serialVersionUID = 4468635881465159780L;
+
+               private final String key;
+
+               TestSharedState(String key) {
+                       this.key = key;
+               }
+
+               @Override
+               public String getKey() {
+                       return key;
+               }
+
+               @Override
+               public void discardState() throws Exception {
+                       // nothing to do
+               }
+
+               @Override
+               public long getStateSize() {
+                       return key.length();
+               }
+
+               @Override
+               public boolean equals(Object o) {
+                       if (this == o) {
+                               return true;
+                       }
+                       if (o == null || getClass() != o.getClass()) {
+                               return false;
+                       }
+
+                       TestSharedState testState = (TestSharedState) o;
+
+                       return key.equals(testState.key);
+               }
+
+               @Override
+               public int hashCode() {
+                       return key.hashCode();
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
new file mode 100644
index 0000000..75b0f6f
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.testutils;
+
+import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
+import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
+import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A checkpoint store, which supports shutdown and suspend. You can use this to test HA
+ * as long as the factory always returns the same store instance.
+ *
+ * <p>The store retains at most one completed checkpoint; adding a newer one subsumes the
+ * older one. Shutting down with a non-globally-terminal job status unregisters the shared
+ * state of the retained checkpoints and parks them in a suspended queue, from which
+ * {@link #recover(SharedStateRegistry)} restores them.
+ */
+public class RecoverableCompletedCheckpointStore implements CompletedCheckpointStore {
+
+       private static final Logger LOG = LoggerFactory.getLogger(RecoverableCompletedCheckpointStore.class);
+
+       /** Maximum number of checkpoints retained before older ones are subsumed. */
+       private static final int MAX_NUMBER_OF_RETAINED_CHECKPOINTS = 1;
+
+       /** Checkpoints currently retained by the store, oldest first. */
+       private final ArrayDeque<CompletedCheckpoint> checkpoints = new ArrayDeque<>(2);
+
+       /** Checkpoints parked by a suspending shutdown, waiting to be recovered. */
+       private final ArrayDeque<CompletedCheckpoint> suspended = new ArrayDeque<>(2);
+
+       /**
+        * Moves all suspended checkpoints back into the store and registers the shared state of
+        * every retained checkpoint at the given registry.
+        */
+       @Override
+       public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+               checkpoints.addAll(suspended);
+               suspended.clear();
+
+               for (CompletedCheckpoint checkpoint : checkpoints) {
+                       checkpoint.registerSharedStates(sharedStateRegistry);
+               }
+       }
+
+       /**
+        * Registers the shared state of the given checkpoint, retains it, and subsumes the oldest
+        * retained checkpoints until no more than {@code MAX_NUMBER_OF_RETAINED_CHECKPOINTS} remain.
+        */
+       @Override
+       public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception {
+               // Register before retaining the checkpoint, so a failed registration leaves the store
+               // unchanged. Registration also still precedes subsumption, as before, so shared state
+               // referenced by both checkpoints presumably keeps a positive reference count.
+               checkpoint.registerSharedStates(sharedStateRegistry);
+               checkpoints.addLast(checkpoint);
+
+               while (checkpoints.size() > MAX_NUMBER_OF_RETAINED_CHECKPOINTS) {
+                       CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst();
+                       LOG.debug("Subsuming checkpoint {}.", checkpointToSubsume.getCheckpointID());
+                       checkpointToSubsume.discardOnSubsume(sharedStateRegistry);
+               }
+       }
+
+       @Override
+       public CompletedCheckpoint getLatestCheckpoint() throws Exception {
+               return checkpoints.isEmpty() ? null : checkpoints.getLast();
+       }
+
+       /**
+        * For a globally terminal job status, drops all checkpoints without discarding them.
+        * Otherwise unregisters the shared state of every retained checkpoint and moves it to the
+        * suspended queue so that a later {@link #recover(SharedStateRegistry)} can restore it.
+        */
+       @Override
+       public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+               if (jobStatus.isGloballyTerminalState()) {
+                       checkpoints.clear();
+                       suspended.clear();
+               } else {
+                       suspended.clear();
+
+                       for (CompletedCheckpoint checkpoint : checkpoints) {
+                               sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values());
+                               suspended.add(checkpoint);
+                       }
+
+                       checkpoints.clear();
+               }
+       }
+
+       @Override
+       public List<CompletedCheckpoint> getAllCheckpoints() throws Exception {
+               return new ArrayList<>(checkpoints);
+       }
+
+       @Override
+       public int getNumberOfRetainedCheckpoints() {
+               return checkpoints.size();
+       }
+
+       @Override
+       public int getMaxNumberOfRetainedCheckpoints() {
+               return MAX_NUMBER_OF_RETAINED_CHECKPOINTS;
+       }
+
+       @Override
+       public boolean requiresExternalizedCheckpoints() {
+               return false;
+       }
+}

Reply via email to