http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index f77c755..aa1726b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -21,13 +21,14 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.messages.CheckpointMessagesTest; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.util.TestLogger; import org.junit.Test; +import org.mockito.Mockito; import java.io.IOException; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -36,6 +37,9 @@ import java.util.concurrent.CountDownLatch; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Test for basic {@link CompletedCheckpointStore} contract. 
@@ -64,7 +68,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { @Test public void testAddAndGetLatestCheckpoint() throws Exception { CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4); - + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + // Empty state assertEquals(0, checkpoints.getNumberOfRetainedCheckpoints()); assertEquals(0, checkpoints.getAllCheckpoints().size()); @@ -73,11 +78,11 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { createCheckpoint(0), createCheckpoint(1) }; // Add and get latest - checkpoints.addCheckpoint(expected[0]); + checkpoints.addCheckpoint(expected[0], sharedStateRegistry); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); verifyCheckpoint(expected[0], checkpoints.getLatestCheckpoint()); - checkpoints.addCheckpoint(expected[1]); + checkpoints.addCheckpoint(expected[1], sharedStateRegistry); assertEquals(2, checkpoints.getNumberOfRetainedCheckpoints()); verifyCheckpoint(expected[1], checkpoints.getLatestCheckpoint()); } @@ -88,7 +93,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { */ @Test public void testAddCheckpointMoreThanMaxRetained() throws Exception { - CompletedCheckpointStore checkpoints = createCompletedCheckpoints(1); + CompletedCheckpointStore checkpoints = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] { createCheckpoint(0), createCheckpoint(1), @@ -96,16 +102,24 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { }; // Add checkpoints - checkpoints.addCheckpoint(expected[0]); + checkpoints.addCheckpoint(expected[0], sharedStateRegistry); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); for (int i = 1; i < expected.length; i++) { - checkpoints.addCheckpoint(expected[i]); + Collection<TaskState> taskStates = expected[i - 
1].getTaskStates().values(); + + checkpoints.addCheckpoint(expected[i], sharedStateRegistry); // The ZooKeeper implementation discards asynchronously expected[i - 1].awaitDiscard(); assertTrue(expected[i - 1].isDiscarded()); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); + + for (TaskState taskState : taskStates) { + for (SubtaskState subtaskState : taskState.getStates()) { + verify(subtaskState, times(1)).unregisterSharedStates(sharedStateRegistry); + } + } } } @@ -132,6 +146,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { @Test public void testGetAllCheckpoints() throws Exception { CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] { createCheckpoint(0), createCheckpoint(1), @@ -139,7 +154,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { }; for (TestCompletedCheckpoint checkpoint : expected) { - checkpoints.addCheckpoint(checkpoint); + checkpoints.addCheckpoint(checkpoint, sharedStateRegistry); } List<CompletedCheckpoint> actual = checkpoints.getAllCheckpoints(); @@ -157,6 +172,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { @Test public void testDiscardAllCheckpoints() throws Exception { CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] { createCheckpoint(0), createCheckpoint(1), @@ -164,10 +180,10 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { }; for (TestCompletedCheckpoint checkpoint : expected) { - checkpoints.addCheckpoint(checkpoint); + checkpoints.addCheckpoint(checkpoint, sharedStateRegistry); } - checkpoints.shutdown(JobStatus.FINISHED); + checkpoints.shutdown(JobStatus.FINISHED, sharedStateRegistry); // Empty state 
assertNull(checkpoints.getLatestCheckpoint()); @@ -203,15 +219,39 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { taskGroupStates.put(jvid, taskState); for (int i = 0; i < numberOfStates; i++) { - ChainedStateHandle<StreamStateHandle> stateHandle = CheckpointCoordinatorTest.generateChainedStateHandle( - new CheckpointMessagesTest.MyHandle()); + SubtaskState subtaskState = CheckpointCoordinatorTest.mockSubtaskState(jvid, i, new KeyGroupRange(i, i)); - taskState.putState(i, new SubtaskState(stateHandle, null, null, null, null)); + taskState.putState(i, subtaskState); } return new TestCompletedCheckpoint(new JobID(), id, 0, taskGroupStates, props); } + protected void resetCheckpoint(Collection<TaskState> taskStates) { + for (TaskState taskState : taskStates) { + for (SubtaskState subtaskState : taskState.getStates()) { + Mockito.reset(subtaskState); + } + } + } + + protected void verifyCheckpointRegistered(Collection<TaskState> taskStates, SharedStateRegistry sharedStateRegistry) { + for (TaskState taskState : taskStates) { + for (SubtaskState subtaskState : taskState.getStates()) { + verify(subtaskState, times(1)).registerSharedStates(eq(sharedStateRegistry)); + } + } + } + + protected void verifyCheckpointDiscarded(Collection<TaskState> taskStates) { + for (TaskState taskState : taskStates) { + for (SubtaskState subtaskState : taskState.getStates()) { + verify(subtaskState, times(1)).discardSharedStatesOnFail(); + verify(subtaskState, times(1)).discardState(); + } + } + } + private void verifyCheckpoint(CompletedCheckpoint expected, CompletedCheckpoint actual) { assertEquals(expected, actual); } @@ -241,8 +281,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { } @Override - public boolean subsume() throws Exception { - if (super.subsume()) { + public boolean discardOnSubsume(SharedStateRegistry sharedStateRegistry) throws Exception { + if (super.discardOnSubsume(sharedStateRegistry)) { discard(); return true; } 
else { @@ -251,8 +291,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { } @Override - public boolean discard(JobStatus jobStatus) throws Exception { - if (super.discard(jobStatus)) { + public boolean discardOnShutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception { + if (super.discardOnShutdown(jobStatus, sharedStateRegistry)) { discard(); return true; } else {
http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index b34e9a6..0b759d4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -23,6 +23,8 @@ import org.apache.flink.core.fs.Path; import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.SharedStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.junit.Rule; import org.junit.Test; @@ -30,10 +32,12 @@ import org.junit.rules.TemporaryFolder; import org.mockito.Mockito; import java.io.File; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -61,7 +65,7 @@ public class CompletedCheckpointTest { new FileStateHandle(new Path(file.toURI()), file.length()), file.getAbsolutePath()); - checkpoint.discard(JobStatus.FAILED); + checkpoint.discardOnShutdown(JobStatus.FAILED, new SharedStateRegistry()); assertEquals(false, file.exists()); } @@ -80,10 +84,15 @@ public class CompletedCheckpointTest { CompletedCheckpoint checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, taskStates, props); + SharedStateRegistry sharedStateRegistry 
= new SharedStateRegistry(); + checkpoint.registerSharedStates(sharedStateRegistry); + verify(state, times(1)).registerSharedStates(sharedStateRegistry); + // Subsume - checkpoint.subsume(); + checkpoint.discardOnSubsume(sharedStateRegistry); verify(state, times(1)).discardState(); + verify(state, times(1)).unregisterSharedStates(sharedStateRegistry); } /** @@ -112,17 +121,22 @@ public class CompletedCheckpointTest { new FileStateHandle(new Path(file.toURI()), file.length()), externalPath); - checkpoint.discard(status); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + checkpoint.registerSharedStates(sharedStateRegistry); + + checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(0)).discardState(); assertEquals(true, file.exists()); + verify(state, times(0)).unregisterSharedStates(sharedStateRegistry); // Discard props = new CheckpointProperties(false, false, true, true, true, true, true); checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, new HashMap<>(taskStates), props); - checkpoint.discard(status); + checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(1)).discardState(); + verify(state, times(1)).unregisterSharedStates(sharedStateRegistry); } } @@ -146,7 +160,7 @@ public class CompletedCheckpointTest { CompletedCheckpointStats.DiscardCallback callback = mock(CompletedCheckpointStats.DiscardCallback.class); completed.setDiscardCallback(callback); - completed.discard(JobStatus.FINISHED); + completed.discardOnShutdown(JobStatus.FINISHED, new SharedStateRegistry()); verify(callback, times(1)).notifyDiscardedCheckpoint(); } http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java index 0ab031e..e7c1c3b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java @@ -31,6 +31,7 @@ import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.jobmanager.scheduler.Scheduler; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.util.SerializedValue; @@ -40,6 +41,8 @@ import org.mockito.Matchers; import java.net.URL; import java.util.Collections; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -59,7 +62,7 @@ public class ExecutionGraphCheckpointCoordinatorTest { graph.fail(new Exception("Test Exception")); verify(counter, times(1)).shutdown(JobStatus.FAILED); - verify(store, times(1)).shutdown(JobStatus.FAILED); + verify(store, times(1)).shutdown(eq(JobStatus.FAILED), any(SharedStateRegistry.class)); } /** @@ -75,8 +78,8 @@ public class ExecutionGraphCheckpointCoordinatorTest { graph.suspend(new Exception("Test Exception")); // No shutdown - verify(counter, times(1)).shutdown(Matchers.eq(JobStatus.SUSPENDED)); - verify(store, times(1)).shutdown(Matchers.eq(JobStatus.SUSPENDED)); + verify(counter, times(1)).shutdown(eq(JobStatus.SUSPENDED)); + verify(store, times(1)).shutdown(eq(JobStatus.SUSPENDED), any(SharedStateRegistry.class)); } private ExecutionGraph createExecutionGraphAndEnableCheckpointing( 
http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index a15684c..d77fac1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -24,15 +24,19 @@ import org.apache.flink.runtime.concurrent.Future; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.SharedStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.mockito.Mock; import org.mockito.Mockito; import java.io.File; import java.lang.reflect.Field; import java.util.ArrayDeque; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Queue; @@ -45,7 +49,10 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.powermock.api.mockito.PowerMockito.when; @@ -184,9 +191,12 @@ public class PendingCheckpointTest { @SuppressWarnings("unchecked") public void testAbortDiscardsState() throws 
Exception { CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); - TaskState state = mock(TaskState.class); QueueExecutor executor = new QueueExecutor(); + TaskState state = mock(TaskState.class); + doNothing().when(state).registerSharedStates(any(SharedStateRegistry.class)); + doNothing().when(state).unregisterSharedStates(any(SharedStateRegistry.class)); + String targetDir = tmpFolder.newFolder().getAbsolutePath(); // Abort declined @@ -197,6 +207,7 @@ public class PendingCheckpointTest { // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); + verify(state, times(1)).discardSharedStatesOnFail(); // Abort error Mockito.reset(state); @@ -208,6 +219,7 @@ public class PendingCheckpointTest { // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); + verify(state, times(1)).discardSharedStatesOnFail(); // Abort expired Mockito.reset(state); @@ -219,6 +231,7 @@ public class PendingCheckpointTest { // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); + verify(state, times(1)).discardSharedStatesOnFail(); // Abort subsumed Mockito.reset(state); @@ -230,6 +243,7 @@ public class PendingCheckpointTest { // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); + verify(state, times(1)).discardSharedStatesOnFail(); } /** @@ -340,7 +354,11 @@ public class PendingCheckpointTest { return createPendingCheckpoint(props, targetDirectory, Executors.directExecutor()); } - private static PendingCheckpoint createPendingCheckpoint(CheckpointProperties props, String targetDirectory, Executor executor) { + private static PendingCheckpoint createPendingCheckpoint( + CheckpointProperties props, + String targetDirectory, + Executor executor) { + Map<ExecutionAttemptID, ExecutionVertex> ackTasks = new 
HashMap<>(ACK_TASKS); return new PendingCheckpoint( new JobID(), http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java index cc7b2d0..7a85897 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java @@ -19,9 +19,12 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.junit.Test; import java.io.IOException; +import java.util.Collection; +import java.util.Collections; import java.util.List; import static org.junit.Assert.assertEquals; @@ -49,15 +52,18 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS @Test public void testShutdownDiscardsCheckpoints() throws Exception { CompletedCheckpointStore store = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); + Collection<TaskState> taskStates = checkpoint.getTaskStates().values(); - store.addCheckpoint(checkpoint); + store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); + verifyCheckpointRegistered(taskStates, sharedStateRegistry); - store.shutdown(JobStatus.FINISHED); - + store.shutdown(JobStatus.FINISHED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertTrue(checkpoint.isDiscarded()); + 
verifyCheckpointDiscarded(taskStates); } /** @@ -67,15 +73,18 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS @Test public void testSuspendDiscardsCheckpoints() throws Exception { CompletedCheckpointStore store = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); + Collection<TaskState> taskStates = checkpoint.getTaskStates().values(); - store.addCheckpoint(checkpoint); + store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); + verifyCheckpointRegistered(taskStates, sharedStateRegistry); - store.shutdown(JobStatus.SUSPENDED); - + store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertTrue(checkpoint.isDiscarded()); + verifyCheckpointDiscarded(taskStates); } /** @@ -87,14 +96,16 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS final int numCheckpointsToRetain = 1; CompletedCheckpointStore store = createCompletedCheckpoints(numCheckpointsToRetain); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); for (long i = 0; i <= numCheckpointsToRetain; ++i) { CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class); doReturn(i).when(checkpointToAdd).getCheckpointID(); - doThrow(new IOException()).when(checkpointToAdd).subsume(); + doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates(); + doThrow(new IOException()).when(checkpointToAdd).discardOnSubsume(sharedStateRegistry); try { - store.addCheckpoint(checkpointToAdd); + store.addCheckpoint(checkpointToAdd, sharedStateRegistry); // The checkpoint should be in the store if we successfully add it into the store. 
List<CompletedCheckpoint> addedCheckpoints = store.getAllCheckpoints(); http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java index 625999a..607e773 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java @@ -22,6 +22,7 @@ import org.apache.curator.framework.CuratorFramework; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.state.RetrievableStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment; import org.junit.AfterClass; @@ -80,22 +81,32 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint @Test public void testRecover() throws Exception { CompletedCheckpointStore checkpoints = createCompletedCheckpoints(3); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] { createCheckpoint(0), createCheckpoint(1), createCheckpoint(2) }; // Add multiple checkpoints - checkpoints.addCheckpoint(expected[0]); - checkpoints.addCheckpoint(expected[1]); - checkpoints.addCheckpoint(expected[2]); + checkpoints.addCheckpoint(expected[0], sharedStateRegistry); + checkpoints.addCheckpoint(expected[1], sharedStateRegistry); + 
checkpoints.addCheckpoint(expected[2], sharedStateRegistry); + + verifyCheckpointRegistered(expected[0].getTaskStates().values(), sharedStateRegistry); + verifyCheckpointRegistered(expected[1].getTaskStates().values(), sharedStateRegistry); + verifyCheckpointRegistered(expected[2].getTaskStates().values(), sharedStateRegistry); // All three should be in ZK assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size()); assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints()); + resetCheckpoint(expected[0].getTaskStates().values()); + resetCheckpoint(expected[1].getTaskStates().values()); + resetCheckpoint(expected[2].getTaskStates().values()); + // Recover - checkpoints.recover(); + SharedStateRegistry newSharedStateRegistry = new SharedStateRegistry(); + checkpoints.recover(newSharedStateRegistry); assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size()); assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints()); @@ -106,11 +117,15 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint expectedCheckpoints.add(expected[2]); expectedCheckpoints.add(createCheckpoint(3)); - checkpoints.addCheckpoint(expectedCheckpoints.get(2)); + checkpoints.addCheckpoint(expectedCheckpoints.get(2), newSharedStateRegistry); List<CompletedCheckpoint> actualCheckpoints = checkpoints.getAllCheckpoints(); assertEquals(expectedCheckpoints, actualCheckpoints); + + for (CompletedCheckpoint actualCheckpoint : actualCheckpoints) { + verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), newSharedStateRegistry); + } } /** @@ -121,18 +136,18 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint CuratorFramework client = ZooKeeper.getClient(); CompletedCheckpointStore store = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - 
store.addCheckpoint(checkpoint); + store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); - store.shutdown(JobStatus.FINISHED); - + store.shutdown(JobStatus.FINISHED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); - store.recover(); + store.recover(sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); } @@ -146,19 +161,20 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint CuratorFramework client = ZooKeeper.getClient(); CompletedCheckpointStore store = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - store.addCheckpoint(checkpoint); + store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); - store.shutdown(JobStatus.SUSPENDED); + store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); // Recover again - store.recover(); + store.recover(sharedStateRegistry); CompletedCheckpoint recovered = store.getLatestCheckpoint(); assertEquals(checkpoint, recovered); http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java index aa2ec85..1f5731d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java @@ -27,6 +27,7 @@ import org.apache.curator.utils.EnsurePath; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.state.RetrievableStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore; import org.apache.flink.util.TestLogger; @@ -40,6 +41,7 @@ import org.powermock.modules.junit4.PowerMockRunner; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.concurrent.Executor; @@ -158,7 +160,9 @@ public class ZooKeeperCompletedCheckpointStoreTest extends TestLogger { stateSotrage, Executors.directExecutor()); - zooKeeperCompletedCheckpointStore.recover(); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + zooKeeperCompletedCheckpointStore.recover(sharedStateRegistry); CompletedCheckpoint latestCompletedCheckpoint = zooKeeperCompletedCheckpointStore.getLatestCheckpoint(); @@ -222,14 +226,17 @@ public class ZooKeeperCompletedCheckpointStoreTest extends TestLogger { checkpointsPath, stateSotrage, Executors.directExecutor()); + + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); for (long i = 0; i <= numCheckpointsToRetain; ++i) { CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class); doReturn(i).when(checkpointToAdd).getCheckpointID(); + doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates(); try { - 
zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd); + zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd, sharedStateRegistry); // The checkpoint should be in the store if we successfully add it into the store. List<CompletedCheckpoint> addedCheckpoints = zooKeeperCompletedCheckpointStore.getAllCheckpoints(); http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index 6eacaac..77eb566 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -71,6 +71,7 @@ import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; @@ -81,6 +82,7 @@ import org.apache.flink.runtime.testingUtils.TestingMessages; import org.apache.flink.runtime.testingUtils.TestingTaskManager; import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages; import org.apache.flink.runtime.testingUtils.TestingUtils; +import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; import org.apache.flink.util.InstantiationUtil; @@ -164,7 +166,7 
@@ public class JobManagerHARecoveryTest { Scheduler scheduler = new Scheduler(TestingUtils.defaultExecutionContext()); MySubmittedJobGraphStore mySubmittedJobGraphStore = new MySubmittedJobGraphStore(); - MyCheckpointStore checkpointStore = new MyCheckpointStore(); + CompletedCheckpointStore checkpointStore = new RecoverableCompletedCheckpointStore(); CheckpointIDCounter checkpointCounter = new StandaloneCheckpointIDCounter(); CheckpointRecoveryFactory checkpointStateFactory = new MyCheckpointRecoveryFactory(checkpointStore, checkpointCounter); TestingLeaderElectionService myLeaderElectionService = new TestingLeaderElectionService(); @@ -438,67 +440,6 @@ public class JobManagerHARecoveryTest { } } - /** - * A checkpoint store, which supports shutdown and suspend. You can use this to test HA - * as long as the factory always returns the same store instance. - */ - static class MyCheckpointStore implements CompletedCheckpointStore { - - private final ArrayDeque<CompletedCheckpoint> checkpoints = new ArrayDeque<>(2); - - private final ArrayDeque<CompletedCheckpoint> suspended = new ArrayDeque<>(2); - - @Override - public void recover() throws Exception { - checkpoints.addAll(suspended); - suspended.clear(); - } - - @Override - public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { - checkpoints.addLast(checkpoint); - if (checkpoints.size() > 1) { - checkpoints.removeFirst().subsume(); - } - } - - @Override - public CompletedCheckpoint getLatestCheckpoint() throws Exception { - return checkpoints.isEmpty() ? 
null : checkpoints.getLast(); - } - - @Override - public void shutdown(JobStatus jobStatus) throws Exception { - if (jobStatus.isGloballyTerminalState()) { - checkpoints.clear(); - suspended.clear(); - } else { - suspended.addAll(checkpoints); - checkpoints.clear(); - } - } - - @Override - public List<CompletedCheckpoint> getAllCheckpoints() throws Exception { - return new ArrayList<>(checkpoints); - } - - @Override - public int getNumberOfRetainedCheckpoints() { - return checkpoints.size(); - } - - @Override - public int getMaxNumberOfRetainedCheckpoints() { - return 1; - } - - @Override - public boolean requiresExternalizedCheckpoints() { - return false; - } - } - static class MyCheckpointRecoveryFactory implements CheckpointRecoveryFactory { private final CompletedCheckpointStore store; http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java new file mode 100644 index 0000000..cb14ff0 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.flink.runtime.state; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + /** + * Tests for {@link SharedStateRegistry}: reference counts after register/unregister calls, and the + * {@link IllegalStateException}s thrown for invalid registrations and unregistrations. + */ +public class SharedStateRegistryTest { + + /** + * Validates that states can be registered and unregistered at the registry and that + * {@link SharedStateRegistry#getReferenceCount} tracks them: registering with {@code true} yields + * a count of 1, registering an existing state again with {@code false} increments the count, and + * each unregister decrements it. + */ + @Test + public void testRegistryNormal() { + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + // create the first state (reference count 1) + TestSharedState firstState = new TestSharedState("first"); + sharedStateRegistry.register(firstState, true); + assertEquals(1, sharedStateRegistry.getReferenceCount(firstState)); + + // create a second, independent state (reference count 1) + TestSharedState secondState = new TestSharedState("second"); + sharedStateRegistry.register(secondState, true); + assertEquals(1, sharedStateRegistry.getReferenceCount(secondState)); + + // register another reference to the already created first state + sharedStateRegistry.register(firstState, false); + assertEquals(2, sharedStateRegistry.getReferenceCount(firstState)); + + // unregister the only reference to the second state + sharedStateRegistry.unregister(secondState); + assertEquals(0, sharedStateRegistry.getReferenceCount(secondState)); + + // unregister one reference to the first state; one reference remains + sharedStateRegistry.unregister(firstState); + assertEquals(1, sharedStateRegistry.getReferenceCount(firstState)); + } + + /** + * Validates that registering a reference ({@code false}) to a state that was never created + * throws an {@link IllegalStateException}. + */ + @Test(expected = IllegalStateException.class) + public void testRegisterWithUncreatedReference() { + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + // register a reference to a state that was never created + TestSharedState
state = new TestSharedState("state"); + sharedStateRegistry.register(state, false); + } + + /** + * Validates that creating ({@code true}) the same state twice throws an + * {@link IllegalStateException}. + */ + @Test(expected = IllegalStateException.class) + public void testRegisterWithDuplicateState() { + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + // create the state, then try to create it a second time + TestSharedState state = new TestSharedState("state"); + sharedStateRegistry.register(state, true); + sharedStateRegistry.register(state, true); + } + + /** + * Validates that unregistering a state that was never registered throws an + * {@link IllegalStateException}. + */ + @Test(expected = IllegalStateException.class) + public void testUnregisterWithUnexistedKey() { + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + sharedStateRegistry.unregister(new TestSharedState("unexisted")); + } + /** + * Minimal {@link SharedStateHandle} for these tests. Equality and hash code depend only on the + * key, the reported state size is the key's length, and discarding does nothing. + */ + private static class TestSharedState implements SharedStateHandle { + private static final long serialVersionUID = 4468635881465159780L; + /** Identity of this handle; returned by getKey() and used by equals/hashCode. */ + private String key; + + TestSharedState(String key) { + this.key = key; + } + + @Override + public String getKey() { + return key; + } + + @Override + public void discardState() throws Exception { + // nothing to discard; this handle only holds a key + } + + @Override + public long getStateSize() { + return key.length(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TestSharedState testState = (TestSharedState) o; + + return key.equals(testState.key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java new file mode 100644 index 0000000..75b0f6f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.testutils; + +import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; +import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; +import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.state.SharedStateRegistry; +import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; + +/** + * An in-memory {@link CompletedCheckpointStore} for tests that supports shutdown and suspend. + * It retains at most one checkpoint: adding a new checkpoint subsumes the previous one. + * + * <p>Shutting down with a {@link JobStatus} that is not globally terminal unregisters the task + * states of the held checkpoints from the {@link SharedStateRegistry} and sets them aside; a + * later {@link #recover(SharedStateRegistry)} restores them and re-registers their shared states. + * Because the suspended checkpoints live in this object, it can only be used to test HA if the + * recovery factory always returns the same store instance.
+ */ +public class RecoverableCompletedCheckpointStore implements CompletedCheckpointStore { + /* + * NOTE(review): LOG is never used, and the StateObject and StateUtil imports above appear + * unused as well; consider removing them. + */ + private static final Logger LOG = LoggerFactory.getLogger(RecoverableCompletedCheckpointStore.class); + /** Checkpoints currently held by the store, oldest first. */ + private final ArrayDeque<CompletedCheckpoint> checkpoints = new ArrayDeque<>(2); + /** Checkpoints set aside by a non-terminal shutdown, to be restored by recover. */ + private final ArrayDeque<CompletedCheckpoint> suspended = new ArrayDeque<>(2); + /** + * Moves the checkpoints set aside by a previous non-terminal shutdown back into the store and + * registers the shared states of every held checkpoint (including any that were already held) + * with the given registry. + */ + @Override + public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { + checkpoints.addAll(suspended); + suspended.clear(); + + for (CompletedCheckpoint checkpoint : checkpoints) { + checkpoint.registerSharedStates(sharedStateRegistry); + } + } + /** + * Appends the checkpoint and registers its shared states. If more than one checkpoint is then + * held, the oldest is removed and subsumed via {@code discardOnSubsume}. + * + * <p>NOTE(review): the checkpoint is appended before its shared states are registered, so if + * {@code registerSharedStates} throws, the checkpoint remains in the store. + */ + @Override + public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception { + checkpoints.addLast(checkpoint); + + checkpoint.registerSharedStates(sharedStateRegistry); + /* Retain only the newest checkpoint. */ + if (checkpoints.size() > 1) { + CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); + checkpointToSubsume.discardOnSubsume(sharedStateRegistry); + } + } + /** Returns the most recently added checkpoint, or {@code null} if the store holds none. */ + @Override + public CompletedCheckpoint getLatestCheckpoint() throws Exception { + return checkpoints.isEmpty() ?
null : checkpoints.getLast(); + } + /** + * Shuts the store down. For a globally terminal {@code jobStatus}, all checkpoints are dropped + * (NOTE(review): without being discarded or unregistered; confirm this is intended). Otherwise, + * previously suspended checkpoints are dropped. The task states of each held checkpoint are then + * unregistered from the registry, and the held checkpoints are set aside for a later recover. + */ + @Override + public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception { + if (jobStatus.isGloballyTerminalState()) { + checkpoints.clear(); + suspended.clear(); + } else { + suspended.clear(); + + for (CompletedCheckpoint checkpoint : checkpoints) { + sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values()); + suspended.add(checkpoint); + } + + checkpoints.clear(); + } + } + /** Returns a copy of the held checkpoints, oldest first. */ + @Override + public List<CompletedCheckpoint> getAllCheckpoints() throws Exception { + return new ArrayList<>(checkpoints); + } + + @Override + public int getNumberOfRetainedCheckpoints() { + return checkpoints.size(); + } + /** Always 1, matching the single-checkpoint retention of addCheckpoint. */ + @Override + public int getMaxNumberOfRetainedCheckpoints() { + return 1; + } + /** Always {@code false}: checkpoints are only kept in this object's in-memory queues. */ + @Override + public boolean requiresExternalizedCheckpoints() { + return false; + } +}
