http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java index 344b340..88b95f5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java @@ -23,14 +23,15 @@ import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.TestLogger; + import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -42,8 +43,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Matchers.anyInt; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -89,29 +90,26 @@ public class CheckpointCoordinatorFailureTest extends TestLogger { assertFalse(pendingCheckpoint.isDiscarded()); final long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next(); - - SubtaskState subtaskState = mock(SubtaskState.class); + StreamStateHandle legacyHandle = mock(StreamStateHandle.class); - ChainedStateHandle<StreamStateHandle> chainedLegacyHandle = mock(ChainedStateHandle.class); - when(chainedLegacyHandle.get(anyInt())).thenReturn(legacyHandle); - when(subtaskState.getLegacyOperatorState()).thenReturn(chainedLegacyHandle); + KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class); + KeyedStateHandle rawKeyedHandle = mock(KeyedStateHandle.class); + OperatorStateHandle managedOpHandle = mock(OperatorStateHandle.class); + OperatorStateHandle rawOpHandle = mock(OperatorStateHandle.class); - OperatorStateHandle managedHandle = mock(OperatorStateHandle.class); - ChainedStateHandle<OperatorStateHandle> chainedManagedHandle = mock(ChainedStateHandle.class); - when(chainedManagedHandle.get(anyInt())).thenReturn(managedHandle); - when(subtaskState.getManagedOperatorState()).thenReturn(chainedManagedHandle); + final OperatorSubtaskState operatorSubtaskState = spy(new OperatorSubtaskState( + legacyHandle, + managedOpHandle, + rawOpHandle, + managedKeyedHandle, + rawKeyedHandle)); - OperatorStateHandle rawHandle = mock(OperatorStateHandle.class); - ChainedStateHandle<OperatorStateHandle> chainedRawHandle = mock(ChainedStateHandle.class); - when(chainedRawHandle.get(anyInt())).thenReturn(rawHandle); - when(subtaskState.getRawOperatorState()).thenReturn(chainedRawHandle); + TaskStateSnapshot subtaskState = spy(new TaskStateSnapshot()); + subtaskState.putSubtaskStateByOperatorID(new OperatorID(), operatorSubtaskState); + + when(subtaskState.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(vertex.getJobvertexId()))).thenReturn(operatorSubtaskState); - KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class); - when(subtaskState.getRawKeyedState()).thenReturn(managedKeyedHandle); - KeyedStateHandle managedRawHandle = mock(KeyedStateHandle.class); - when(subtaskState.getManagedKeyedState()).thenReturn(managedRawHandle); - AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState); try { @@ -126,11 +124,12 @@ public class CheckpointCoordinatorFailureTest extends TestLogger { assertTrue(pendingCheckpoint.isDiscarded()); // make sure that the subtask state has been discarded after we could not complete it. - verify(subtaskState.getLegacyOperatorState().get(0)).discardState(); - verify(subtaskState.getManagedOperatorState().get(0)).discardState(); - verify(subtaskState.getRawOperatorState().get(0)).discardState(); - verify(subtaskState.getManagedKeyedState()).discardState(); - verify(subtaskState.getRawKeyedState()).discardState(); + verify(operatorSubtaskState).discardState(); + verify(operatorSubtaskState.getLegacyOperatorState()).discardState(); + verify(operatorSubtaskState.getManagedOperatorState().iterator().next()).discardState(); + verify(operatorSubtaskState.getRawOperatorState().iterator().next()).discardState(); + verify(operatorSubtaskState.getManagedKeyedState().iterator().next()).discardState(); + verify(operatorSubtaskState.getRawKeyedState().iterator().next()).discardState(); } private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore {
http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index cb92df6..d9af879 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -44,7 +44,6 @@ import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.testutils.CommonTestUtils; @@ -93,7 +92,6 @@ import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -102,7 +100,6 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; /** * Tests for the checkpoint coordinator. @@ -555,31 +552,29 @@ public class CheckpointCoordinatorTest extends TestLogger { assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); - OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); - OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); - - Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism())); - // check that the vertices received the trigger checkpoint message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); } + OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); + OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); + TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class); + TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class); + when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1); + when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2); + // acknowledge from one of the tasks - AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); + AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); - OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); - verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); + verify(taskOperatorSubtaskStates2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the same task again (should not matter) coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); @@ -588,8 +583,7 @@ public class CheckpointCoordinatorTest extends TestLogger { verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the other task. - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -628,9 +622,7 @@ public class CheckpointCoordinatorTest extends TestLogger { long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); - subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); - subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -852,18 +844,20 @@ public class CheckpointCoordinatorTest extends TestLogger { OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId()); - Map<OperatorID, OperatorState> operatorStates1 = pending1.getOperatorStates(); + TaskStateSnapshot taskOperatorSubtaskStates1_1 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates1_2 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates1_3 = spy(new TaskStateSnapshot()); - operatorStates1.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates1.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); - operatorStates1.put(opID3, new SpyInjectingOperatorState( - opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); + OperatorSubtaskState subtaskState1_1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState1_2 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState1_3 = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, subtaskState1_1); + taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, subtaskState1_2); + taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, subtaskState1_3); // acknowledge one of the three tasks - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1_2 = operatorStates1.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_2)); + // start the second checkpoint // trigger the first checkpoint. this should succeed assertTrue(coord.triggerCheckpoint(timestamp2, false)); @@ -880,14 +874,17 @@ public class CheckpointCoordinatorTest extends TestLogger { } long checkpointId2 = pending2.getCheckpointId(); - Map<OperatorID, OperatorState> operatorStates2 = pending2.getOperatorStates(); + TaskStateSnapshot taskOperatorSubtaskStates2_1 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates2_2 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates2_3 = spy(new TaskStateSnapshot()); + + OperatorSubtaskState subtaskState2_1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2_2 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2_3 = mock(OperatorSubtaskState.class); - operatorStates2.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates2.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); - operatorStates2.put(opID3, new SpyInjectingOperatorState( - opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); + taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, subtaskState2_1); + taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, subtaskState2_2); + taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, subtaskState2_3); // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); @@ -896,17 +893,13 @@ public class CheckpointCoordinatorTest extends TestLogger { // we acknowledge one more task from the first checkpoint and the second // checkpoint completely. The second checkpoint should then subsume the first checkpoint - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_3 = operatorStates2.get(opID3).getState(ackVertex3.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_3)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_1 = operatorStates2.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_1)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1_1 = operatorStates1.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_1)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_2 = operatorStates2.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_2)); // now, the second checkpoint should be confirmed, and the first discarded // actually both pending checkpoints are discarded, and the second has been transformed @@ -938,8 +931,7 @@ public class CheckpointCoordinatorTest extends TestLogger { verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2)); // send the last remaining ack for the first checkpoint. This should not do anything - SubtaskState subtaskState1_3 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_3)); verify(subtaskState1_3, times(1)).discardState(); coord.shutdown(JobStatus.FINISHED); @@ -1005,13 +997,11 @@ public class CheckpointCoordinatorTest extends TestLogger { OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); - Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates(); + TaskStateSnapshot taskOperatorSubtaskStates1 = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStates1.putSubtaskStateByOperatorID(opID1, subtaskState1); - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState = operatorStates.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), taskOperatorSubtaskStates1)); // wait until the checkpoint must have expired. // we check every 250 msecs conservatively for 5 seconds @@ -1029,7 +1019,7 @@ public class CheckpointCoordinatorTest extends TestLogger { assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // validate that the received states have been discarded - verify(subtaskState, times(1)).discardState(); + verify(subtaskState1, times(1)).discardState(); // no confirm message must have been sent verify(commitVertex.getCurrentExecutionAttempt(), times(0)).notifyCheckpointComplete(anyLong(), anyLong()); @@ -1147,26 +1137,18 @@ public class CheckpointCoordinatorTest extends TestLogger { long checkpointId = pendingCheckpoint.getCheckpointId(); OperatorID opIDtrigger = OperatorID.fromJobVertexID(triggerVertex.getJobvertexId()); - OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); - OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); - - Map<OperatorID, OperatorState> operatorStates = pendingCheckpoint.getOperatorStates(); - operatorStates.put(opIDtrigger, new SpyInjectingOperatorState( - opIDtrigger, triggerVertex.getTotalNumberOfParallelSubtasks(), triggerVertex.getMaxParallelism())); - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); + TaskStateSnapshot taskOperatorSubtaskStatesTrigger = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskStateTrigger = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStatesTrigger.putSubtaskStateByOperatorID(opIDtrigger, subtaskStateTrigger); // acknowledge the first trigger vertex - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState storedTriggerSubtaskState = operatorStates.get(opIDtrigger).getState(triggerVertex.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStatesTrigger)); // verify that the subtask state has not been discarded - verify(storedTriggerSubtaskState, never()).discardState(); + verify(subtaskStateTrigger, never()).discardState(); - SubtaskState unknownSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot unknownSubtaskState = mock(TaskStateSnapshot.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState)); @@ -1174,7 +1156,7 @@ public class CheckpointCoordinatorTest extends TestLogger { // we should discard acknowledge messages from an unknown vertex belonging to our job verify(unknownSubtaskState, times(1)).discardState(); - SubtaskState differentJobSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot differentJobSubtaskState = mock(TaskStateSnapshot.class); // receive an acknowledge message from an unknown job coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); @@ -1183,22 +1165,22 @@ public class CheckpointCoordinatorTest extends TestLogger { verify(differentJobSubtaskState, never()).discardState(); // duplicate acknowledge message for the trigger vertex - SubtaskState triggerSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot triggerSubtaskState = mock(TaskStateSnapshot.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); // duplicate acknowledge messages for a known vertex should not trigger discarding the state verify(triggerSubtaskState, never()).discardState(); // let the checkpoint fail at the first ack vertex - reset(storedTriggerSubtaskState); + reset(subtaskStateTrigger); coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId)); assertTrue(pendingCheckpoint.isDiscarded()); // check that we've cleaned up the already acknowledged state - verify(storedTriggerSubtaskState, times(1)).discardState(); + verify(subtaskStateTrigger, times(1)).discardState(); - SubtaskState ackSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot ackSubtaskState = mock(TaskStateSnapshot.class); // late acknowledge message from the second ack vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState)); @@ -1213,7 +1195,7 @@ public class CheckpointCoordinatorTest extends TestLogger { // we should not interfere with different jobs verify(differentJobSubtaskState, never()).discardState(); - SubtaskState unknownSubtaskState2 = mock(SubtaskState.class); + TaskStateSnapshot unknownSubtaskState2 = mock(TaskStateSnapshot.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2)); @@ -1470,18 +1452,16 @@ public class CheckpointCoordinatorTest extends TestLogger { OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); - - Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); + TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class); + TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class); + when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1); + when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2); // acknowledge from one of the tasks - AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); + AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); - OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, pending.getNumberOfAcknowledgedTasks()); assertEquals(1, pending.getNumberOfNonAcknowledgedTasks()); assertFalse(pending.isDiscarded()); @@ -1495,8 +1475,7 @@ public class CheckpointCoordinatorTest extends TestLogger { assertFalse(savepointFuture.isDone()); // acknowledge the other task. - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -1536,9 +1515,6 @@ public class CheckpointCoordinatorTest extends TestLogger { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); - subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); - subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); - assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -2037,20 +2013,8 @@ public class CheckpointCoordinatorTest extends TestLogger { List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); - PendingCheckpoint pending = coord.getPendingCheckpoints().get(checkpointId); - - OperatorID opID1 = OperatorID.fromJobVertexID(jobVertexID1); - OperatorID opID2 = OperatorID.fromJobVertexID(jobVertexID2); - - Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, jobVertex1.getParallelism(), jobVertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, jobVertex2.getParallelism(), jobVertex2.getMaxParallelism())); - for (int index = 0; index < jobVertex1.getParallelism(); index++) { - SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); + TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, @@ -2063,7 +2027,7 @@ public class CheckpointCoordinatorTest extends TestLogger { } for (int index = 0; index < jobVertex2.getParallelism(); index++) { - SubtaskState subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index)); + TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index)); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, @@ -2165,30 +2129,34 @@ public class CheckpointCoordinatorTest extends TestLogger { List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index); + StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } for (int index = 0; index < jobVertex2.getParallelism(); index++) { - ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index); + StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID2, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2284,17 +2252,20 @@ public class CheckpointCoordinatorTest extends TestLogger { StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index); + StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState( jobVertexID1, keyGroupPartitions1.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2302,17 +2273,19 @@ public class CheckpointCoordinatorTest extends TestLogger { for (int index = 0; index < jobVertex2.getParallelism(); index++) { - ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID2, index); + StreamStateHandle state = generateStateForVertex(jobVertexID2, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState( jobVertexID2, keyGroupPartitions2.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(state, null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2438,18 +2411,21 @@ public class CheckpointCoordinatorTest extends TestLogger { //vertex 1 for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index); - ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false); + StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); + OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID1, index, 2, 8, false); KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2460,19 +2436,21 @@ public class CheckpointCoordinatorTest extends TestLogger { for (int index = 0; index < jobVertex2.getParallelism(); index++) { KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true); - ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false); - ChainedStateHandle<OperatorStateHandle> opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true); - expectedOpStatesBackend.add(opStateBackend); - expectedOpStatesRaw.add(opStateRaw); - SubtaskState checkpointStateHandles = - new SubtaskState(new ChainedStateHandle<>( - Collections.<StreamStateHandle>singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw); + OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, false); + OperatorStateHandle opStateRaw = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, true); + expectedOpStatesBackend.add(new ChainedStateHandle<>(Collections.singletonList(opStateBackend))); + expectedOpStatesRaw.add(new ChainedStateHandle<>(Collections.singletonList(opStateRaw))); + + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2506,27 +2484,37 @@ public class CheckpointCoordinatorTest extends TestLogger { List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism()); List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { - KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); - KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); - TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + List<OperatorID> operatorIDs = newJobVertex2.getOperatorIDs(); - ChainedStateHandle<StreamStateHandle> operatorState = taskStateHandles.getLegacyOperatorState(); - List<Collection<OperatorStateHandle>> opStateBackend = taskStateHandles.getManagedOperatorState(); - List<Collection<OperatorStateHandle>> opStateRaw = taskStateHandles.getRawOperatorState(); - Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState(); - Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState(); + KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); + KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); - actualOpStatesBackend.add(opStateBackend); - actualOpStatesRaw.add(opStateRaw); - // the 'non partition state' is not null because it is recombined. - assertNotNull(operatorState); - for (int index = 0; index < operatorState.getLength(); index++) { - assertNull(operatorState.get(index)); + TaskStateSnapshot taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + + final int headOpIndex = operatorIDs.size() - 1; + List<Collection<OperatorStateHandle>> allParallelManagedOpStates = new ArrayList<>(operatorIDs.size()); + List<Collection<OperatorStateHandle>> allParallelRawOpStates = new ArrayList<>(operatorIDs.size()); + + for (int idx = 0; idx < operatorIDs.size(); ++idx) { + OperatorID operatorID = operatorIDs.get(idx); + OperatorSubtaskState opState = taskStateHandles.getSubtaskStateByOperatorID(operatorID); + Assert.assertNull(opState.getLegacyOperatorState()); + Collection<OperatorStateHandle> opStateBackend = opState.getManagedOperatorState(); + Collection<OperatorStateHandle> opStateRaw = opState.getRawOperatorState(); + allParallelManagedOpStates.add(opStateBackend); + allParallelRawOpStates.add(opStateRaw); + if (idx == headOpIndex) { + Collection<KeyedStateHandle> keyedStateBackend = opState.getManagedKeyedState(); + Collection<KeyedStateHandle> keyGroupStateRaw = opState.getRawKeyedState(); + compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); + compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); + } } - compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); - compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); + actualOpStatesBackend.add(allParallelManagedOpStates); + actualOpStatesRaw.add(allParallelRawOpStates); } + comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend); comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw); } @@ -2578,14 +2566,11 @@ public class CheckpointCoordinatorTest extends TestLogger { operatorStates.put(id.f1, taskState); for (int index = 0; index < taskState.getParallelism(); index++) { StreamStateHandle subNonPartitionedState = - generateStateForVertex(id.f0, index) - .get(0); + generateStateForVertex(id.f0, index); OperatorStateHandle subManagedOperatorState = - generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false) - .get(0); + generatePartitionableStateHandle(id.f0, index, 2, 8, false); OperatorStateHandle subRawOperatorState = - generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true) - .get(0); + generatePartitionableStateHandle(id.f0, index, 2, 8, true); OperatorSubtaskState subtaskState = new OperatorSubtaskState(subNonPartitionedState, subManagedOperatorState, @@ -2707,57 +2692,75 @@ public class CheckpointCoordinatorTest extends TestLogger { for (int i = 0; i < newJobVertex1.getParallelism(); i++) { - TaskStateHandles taskStateHandles = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); - ChainedStateHandle<StreamStateHandle> actualSubNonPartitionedState = taskStateHandles.getLegacyOperatorState(); - List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = taskStateHandles.getManagedOperatorState(); - List<Collection<OperatorStateHandle>> actualSubRawOperatorState = taskStateHandles.getRawOperatorState(); + final List<OperatorID> operatorIds = newJobVertex1.getOperatorIDs(); - assertNull(taskStateHandles.getManagedKeyedState()); - assertNull(taskStateHandles.getRawKeyedState()); + TaskStateSnapshot stateSnapshot = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + + OperatorSubtaskState headOpState = stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1)); + assertTrue(headOpState.getManagedKeyedState().isEmpty()); + assertTrue(headOpState.getRawKeyedState().isEmpty()); // operator5 { int operatorIndexInChain = 2; - assertNull(actualSubNonPartitionedState.get(operatorIndexInChain)); - assertNull(actualSubManagedOperatorState.get(operatorIndexInChain)); - assertNull(actualSubRawOperatorState.get(operatorIndexInChain)); + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + + assertNull(opState.getLegacyOperatorState()); + assertTrue(opState.getManagedOperatorState().isEmpty()); + assertTrue(opState.getRawOperatorState().isEmpty()); } // operator1 { int operatorIndexInChain = 1; - ChainedStateHandle<StreamStateHandle> expectSubNonPartitionedState = generateStateForVertex(id1.f0, i); - ChainedStateHandle<OperatorStateHandle> expectedManagedOpState = generateChainedPartitionableStateHandle( + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + + StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id1.f0, i); + OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle( id1.f0, i, 2, 8, false); - ChainedStateHandle<OperatorStateHandle> expectedRawOpState = generateChainedPartitionableStateHandle( + OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle( id1.f0, i, 2, 8, true); assertTrue(CommonTestUtils.isSteamContentEqual( - expectSubNonPartitionedState.get(0).openInputStream(), - actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream())); - - assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(), - actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); - - assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(), - actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + expectSubNonPartitionedState.openInputStream(), + opState.getLegacyOperatorState().openInputStream())); + + Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState(); + assertEquals(1, managedOperatorState.size()); + assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(), + managedOperatorState.iterator().next().openInputStream())); + + Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState(); + assertEquals(1, rawOperatorState.size()); + assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(), + rawOperatorState.iterator().next().openInputStream())); } // operator2 { int operatorIndexInChain = 0; - ChainedStateHandle<StreamStateHandle> expectSubNonPartitionedState = generateStateForVertex(id2.f0, i); - ChainedStateHandle<OperatorStateHandle> expectedManagedOpState = generateChainedPartitionableStateHandle( + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + + StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id2.f0, i); + OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle( id2.f0, i, 2, 8, false); - ChainedStateHandle<OperatorStateHandle> expectedRawOpState = generateChainedPartitionableStateHandle( + OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle( id2.f0, i, 2, 8, true); - assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.get(0).openInputStream(), - actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream())); - - assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(), - actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); - - assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(), - actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + assertTrue(CommonTestUtils.isSteamContentEqual( + expectSubNonPartitionedState.openInputStream(), + opState.getLegacyOperatorState().openInputStream())); + + Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState(); + assertEquals(1, managedOperatorState.size()); + assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(), + managedOperatorState.iterator().next().openInputStream())); + + Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState(); + assertEquals(1, rawOperatorState.size()); + assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(), + rawOperatorState.iterator().next().openInputStream())); } } @@ -2765,38 +2768,48 @@ public class CheckpointCoordinatorTest extends TestLogger { List<List<Collection<OperatorStateHandle>>> actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { - TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + + final List<OperatorID> operatorIds = newJobVertex2.getOperatorIDs(); + + TaskStateSnapshot stateSnapshot = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); // operator 3 { int operatorIndexInChain = 1; + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = new ArrayList<>(1); - actualSubManagedOperatorState.add(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain)); + actualSubManagedOperatorState.add(opState.getManagedOperatorState()); List<Collection<OperatorStateHandle>> actualSubRawOperatorState = new ArrayList<>(1); - actualSubRawOperatorState.add(taskStateHandles.getRawOperatorState().get(operatorIndexInChain)); + actualSubRawOperatorState.add(opState.getRawOperatorState()); actualManagedOperatorStates.add(actualSubManagedOperatorState); actualRawOperatorStates.add(actualSubRawOperatorState); - assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain)); + assertNull(opState.getLegacyOperatorState()); } // operator 6 { int operatorIndexInChain = 0; - assertNull(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain)); - assertNull(taskStateHandles.getRawOperatorState().get(operatorIndexInChain)); - assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain)); + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + assertNull(opState.getLegacyOperatorState()); + assertTrue(opState.getManagedOperatorState().isEmpty()); + assertTrue(opState.getRawOperatorState().isEmpty()); } KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false); KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true); + OperatorSubtaskState headOpState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1)); - Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState(); - Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState(); + Collection<KeyedStateHandle> keyedStateBackend = headOpState.getManagedKeyedState(); + Collection<KeyedStateHandle> keyGroupStateRaw = headOpState.getRawKeyedState(); compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); @@ -2974,19 +2987,50 @@ public class CheckpointCoordinatorTest extends TestLogger { return new Tuple2<>(allSerializedValuesConcatenated, offsets); } - public static ChainedStateHandle<StreamStateHandle> generateStateForVertex( + public static StreamStateHandle generateStateForVertex( JobVertexID jobVertexID, int index) throws IOException { Random random = new Random(jobVertexID.hashCode() + index); int value = random.nextInt(); - return generateChainedStateHandle(value); + return generateStreamStateHandle(value); + } + + public static StreamStateHandle generateStreamStateHandle(Serializable value) throws IOException { + return TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value); } public static ChainedStateHandle<StreamStateHandle> generateChainedStateHandle( Serializable value) throws IOException { return ChainedStateHandle.wrapSingleHandle( - TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value)); + generateStreamStateHandle(value)); + } + + public static OperatorStateHandle generatePartitionableStateHandle( + JobVertexID jobVertexID, + int index, + int namedStates, + int partitionsPerState, + boolean rawState) throws IOException { + + Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates); + + for (int i = 0; i < namedStates; ++i) { + List<Integer> testStatesLists = new ArrayList<>(partitionsPerState); + // generate state + int seed = jobVertexID.hashCode() * index + i * namedStates; + if (rawState) { + seed = (seed + 1) * 31; + } + Random random = new Random(seed); + for (int j = 0; j < partitionsPerState; ++j) { + int simulatedStateValue = random.nextInt(); + testStatesLists.add(simulatedStateValue); + } + statesListsMap.put("state-" + i, testStatesLists); + } + + return generatePartitionableStateHandle(statesListsMap); } public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle( @@ -3013,11 +3057,11 @@ public class CheckpointCoordinatorTest extends TestLogger { statesListsMap.put("state-" + i, testStatesLists); } - return generateChainedPartitionableStateHandle(statesListsMap); + return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap)); } - private static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle( - Map<String, List<? extends Serializable>> states) throws IOException { + private static OperatorStateHandle generatePartitionableStateHandle( + Map<String, List<? extends Serializable>> states) throws IOException { List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size()); @@ -3032,20 +3076,18 @@ public class CheckpointCoordinatorTest extends TestLogger { int idx = 0; for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) { offsetsMap.put( - entry.getKey(), - new OperatorStateHandle.StateMetaInfo( - serializationWithOffsets.f1.get(idx), - OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + entry.getKey(), + new OperatorStateHandle.StateMetaInfo( + serializationWithOffsets.f1.get(idx), + OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); ++idx; } ByteStreamStateHandle streamStateHandle = new TestByteStreamStateHandleDeepCompare( - String.valueOf(UUID.randomUUID()), - serializationWithOffsets.f0); + String.valueOf(UUID.randomUUID()), + serializationWithOffsets.f0); - OperatorStateHandle operatorStateHandle = - new OperatorStateHandle(offsetsMap, streamStateHandle); - return ChainedStateHandle.wrapSingleHandle(operatorStateHandle); + return new OperatorStateHandle(offsetsMap, streamStateHandle); } static ExecutionJobVertex mockExecutionJobVertex( @@ -3139,24 +3181,23 @@ public class CheckpointCoordinatorTest extends TestLogger { return vertex; } - static SubtaskState mockSubtaskState( + static TaskStateSnapshot mockSubtaskState( JobVertexID jobVertexID, int index, KeyGroupRange keyGroupRange) throws IOException { - ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID, index); - ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false); + StreamStateHandle nonPartitionedState = generateStateForVertex(jobVertexID, index); + OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false); KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false); - SubtaskState subtaskState = mock(SubtaskState.class, withSettings().serializable()); + TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState( + nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null) + ); - doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState(); - doReturn(partitionableState).when(subtaskState).getManagedOperatorState(); - doReturn(null).when(subtaskState).getRawOperatorState(); - doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState(); - doReturn(null).when(subtaskState).getRawKeyedState(); + subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState); - return subtaskState; + return subtaskStates; } public static void verifyStateRestore( @@ -3165,27 +3206,27 @@ public class CheckpointCoordinatorTest extends TestLogger { for (int i = 0; i < executionJobVertex.getParallelism(); i++) { - TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + final List<OperatorID> operatorIds = executionJobVertex.getOperatorIDs(); - ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i); - ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = taskStateHandles.getLegacyOperatorState(); + TaskStateSnapshot stateSnapshot = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + + OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID)); + + StreamStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i); assertTrue(CommonTestUtils.isSteamContentEqual( - expectNonPartitionedState.get(0).openInputStream(), - actualNonPartitionedState.get(0).openInputStream())); + expectNonPartitionedState.openInputStream(), + operatorState.getLegacyOperatorState().openInputStream())); ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend = generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false); - List<Collection<OperatorStateHandle>> actualPartitionableState = taskStateHandles.getManagedOperatorState(); - assertTrue(CommonTestUtils.isSteamContentEqual( expectedOpStateBackend.get(0).openInputStream(), - actualPartitionableState.get(0).iterator().next().openInputStream())); + operatorState.getManagedOperatorState().iterator().next().openInputStream())); KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState( jobVertexID, keyGroupPartitions.get(i), false); - Collection<KeyedStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState(); - compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState); + compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState()); } } @@ -3632,17 +3673,4 @@ public class CheckpointCoordinatorTest extends TestLogger { "The latest completed (proper) checkpoint should have been added to the completed checkpoint store.", completedCheckpointStore.getLatestCheckpoint().getCheckpointID() == checkpointIDCounter.getLast()); } - - private static final class SpyInjectingOperatorState extends OperatorState { - - private static final long serialVersionUID = -4004437428483663815L; - - public SpyInjectingOperatorState(OperatorID taskID, int parallelism, int maxParallelism) { - super(taskID, parallelism, maxParallelism); - } - - public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) { - super.putState(subtaskIndex, spy(subtaskState)); - } - } } http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 7d24568..6ce071b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -34,18 +34,18 @@ import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.util.SerializableObject; + import org.hamcrest.BaseMatcher; import org.hamcrest.Description; import org.junit.Test; import org.mockito.Mockito; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -118,10 +118,20 @@ public class CheckpointStateRestoreTest { PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next(); final long checkpointId = pending.getCheckpointId(); - SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); + final TaskStateSnapshot subtaskStates = new TaskStateSnapshot(); + + subtaskStates.putSubtaskStateByOperatorID( + OperatorID.fromJobVertexID(statefulId), + new OperatorSubtaskState( + serializedState.get(0), + Collections.<OperatorStateHandle>emptyList(), + Collections.<OperatorStateHandle>emptyList(), + Collections.singletonList(serializedKeyGroupStates), + Collections.<KeyedStateHandle>emptyList())); + + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId)); @@ -133,33 +143,26 @@ public class CheckpointStateRestoreTest { // verify that each stateful vertex got the state - final TaskStateHandles taskStateHandles = new TaskStateHandles( - serializedState, - Collections.<Collection<OperatorStateHandle>>singletonList(null), - Collections.<Collection<OperatorStateHandle>>singletonList(null), - Collections.singletonList(serializedKeyGroupStates), - null); - - BaseMatcher<TaskStateHandles> matcher = new BaseMatcher<TaskStateHandles>() { + BaseMatcher<TaskStateSnapshot> matcher = new BaseMatcher<TaskStateSnapshot>() { @Override public boolean matches(Object o) { - if (o instanceof TaskStateHandles) { - return o.equals(taskStateHandles); + if (o instanceof TaskStateSnapshot) { + return Objects.equals(o, subtaskStates); } return false; } @Override public void describeTo(Description description) { - description.appendValue(taskStateHandles); + description.appendValue(subtaskStates); } }; verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher)); verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher)); verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher)); - verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateHandles>any()); - verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateHandles>any()); + verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateSnapshot>any()); + verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateSnapshot>any()); } catch (Exception e) { e.printStackTrace(); @@ -250,9 +253,9 @@ public class CheckpointStateRestoreTest { Map<OperatorID, OperatorState> checkpointTaskStates = new HashMap<>(); { OperatorState taskState = new OperatorState(operatorId1, 3, 3); - taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null)); - taskState.putState(1, new OperatorSubtaskState(serializedState, null, null, null, null)); - taskState.putState(2, new OperatorSubtaskState(serializedState, null, null, null, null)); + taskState.putState(0, new OperatorSubtaskState(serializedState)); + taskState.putState(1, new OperatorSubtaskState(serializedState)); + taskState.putState(2, new OperatorSubtaskState(serializedState)); checkpointTaskStates.put(operatorId1, taskState); } @@ -279,7 +282,7 @@ public class CheckpointStateRestoreTest { // There is no task for this { OperatorState taskState = new OperatorState(newOperatorID, 1, 1); - taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null)); + taskState.putState(0, new OperatorSubtaskState(serializedState)); checkpointTaskStates.put(newOperatorID, taskState); } http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index 1fe4e65..320dc2d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -331,7 +331,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { boolean discarded; public TestOperatorSubtaskState() { - super(null, null, null, null, null); + super(); this.registered = false; this.discarded = false; } http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index 7d103d0..7ebb49a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -324,7 +324,7 @@ public class PendingCheckpointTest { @Test public void testNonNullSubtaskStateLeadsToStatefulTask() throws Exception { PendingCheckpoint pending = createPendingCheckpoint(CheckpointProperties.forStandardCheckpoint(), null); - pending.acknowledgeTask(ATTEMPT_ID, mock(SubtaskState.class), mock(CheckpointMetrics.class)); + pending.acknowledgeTask(ATTEMPT_ID, mock(TaskStateSnapshot.class), mock(CheckpointMetrics.class)); Assert.assertFalse(pending.getOperatorStates().isEmpty()); } http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java index 36c9cad..9ed4851 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.blob.BlobKey; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; @@ -30,7 +31,6 @@ import org.apache.flink.runtime.executiongraph.TaskInformation; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.operators.BatchTask; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.SerializedValue; import org.junit.Test; @@ -73,7 +73,7 @@ public class TaskDeploymentDescriptorTest { final SerializedValue<TaskInformation> serializedJobVertexInformation = new SerializedValue<>(new TaskInformation( vertexID, taskName, currentNumberOfSubtasks, numberOfKeyGroups, invokableClass.getName(), taskConfiguration)); final int targetSlotNumber = 47; - final TaskStateHandles taskStateHandles = new TaskStateHandles(); + final TaskStateSnapshot taskStateHandles = new TaskStateSnapshot(); final TaskDeploymentDescriptor orig = new TaskDeploymentDescriptor( serializedJobInformation, http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java index 0eed90d..c9b7a40 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.clusterframework.types.ResourceProfile; @@ -38,7 +39,6 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot; import org.apache.flink.runtime.jobmanager.slots.SlotOwner; import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.testtasks.NoOpInvokable; @@ -51,8 +51,10 @@ import java.net.InetAddress; import java.util.Iterator; import java.util.concurrent.TimeUnit; -import static org.mockito.Mockito.*; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; /** * Tests that the execution vertex handles locality preferences well. @@ -169,7 +171,7 @@ public class ExecutionVertexLocalityTest extends TestLogger { // target state ExecutionVertex target = graph.getAllVertices().get(targetVertexId).getTaskVertices()[i]; - target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateHandles.class)); + target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateSnapshot.class)); } // validate that the target vertices have the state's location as the location preference http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index a63b02d..23f0a38 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -18,16 +18,6 @@ package org.apache.flink.runtime.jobmanager; -import akka.actor.ActorRef; -import akka.actor.ActorSystem; -import akka.actor.Identify; -import akka.actor.PoisonPill; -import akka.actor.Props; -import akka.japi.pf.FI; -import akka.japi.pf.ReceiveBuilder; -import akka.pattern.Patterns; -import akka.testkit.CallingThreadDispatcher; -import akka.testkit.JavaTestKit; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; @@ -44,8 +34,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory; import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager; import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy; @@ -59,6 +50,7 @@ import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings; @@ -69,9 +61,6 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService; import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.metrics.MetricRegistry; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManager; import org.apache.flink.runtime.testingUtils.TestingJobManager; @@ -83,23 +72,24 @@ import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; import org.apache.flink.util.InstantiationUtil; - import org.apache.flink.util.TestLogger; + +import akka.actor.ActorRef; +import akka.actor.ActorSystem; +import akka.actor.Identify; +import akka.actor.PoisonPill; +import akka.actor.Props; +import akka.japi.pf.FI; +import akka.japi.pf.ReceiveBuilder; +import akka.pattern.Patterns; +import akka.testkit.CallingThreadDispatcher; +import akka.testkit.JavaTestKit; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import scala.Int; -import scala.Option; -import scala.PartialFunction; -import scala.concurrent.Await; -import scala.concurrent.Future; -import scala.concurrent.duration.Deadline; -import scala.concurrent.duration.FiniteDuration; -import scala.runtime.BoxedUnit; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -113,6 +103,15 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import scala.Int; +import scala.Option; +import scala.PartialFunction; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; +import scala.runtime.BoxedUnit; + import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; @@ -552,10 +551,10 @@ public class JobManagerHARecoveryTest extends TestLogger { @Override public void setInitialState( - TaskStateHandles taskStateHandles) throws Exception { + TaskStateSnapshot taskStateHandles) throws Exception { int subtaskIndex = getIndexInSubtaskGroup(); if (subtaskIndex < recoveredStates.length) { - try (FSDataInputStream in = taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) { + try (FSDataInputStream in = taskStateHandles.getSubtaskStateMappings().iterator().next().getValue().getLegacyOperatorState().openInputStream()) { recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader()); } } @@ -567,10 +566,11 @@ public class JobManagerHARecoveryTest extends TestLogger { String.valueOf(UUID.randomUUID()), InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId())); - ChainedStateHandle<StreamStateHandle> chainedStateHandle = - new ChainedStateHandle<StreamStateHandle>(Collections.singletonList(byteStreamStateHandle)); - SubtaskState checkpointStateHandles = - new SubtaskState(chainedStateHandle, null, null, null, null); + TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot(); + checkpointStateHandles.putSubtaskStateByOperatorID( + OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()), + new OperatorSubtaskState(byteStreamStateHandle) + ); getEnvironment().acknowledgeCheckpoint( checkpointMetaData.getCheckpointId(), http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java index bc420cc..d022cdc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java @@ -24,14 +24,17 @@ import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.StreamStateHandle; + import org.junit.Test; import java.io.IOException; @@ -68,13 +71,17 @@ public class CheckpointMessagesTest { KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42); - SubtaskState checkpointStateHandles = - new SubtaskState( - CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()), - CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8, false), - null, - CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())), - null); + TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot(); + checkpointStateHandles.putSubtaskStateByOperatorID( + new OperatorID(), + new OperatorSubtaskState( + CheckpointCoordinatorTest.generateStreamStateHandle(new MyHandle()), + CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false), + null, + CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())), + null + ) + ); AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint( new JobID(), http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index 851fa96..8ed06b2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -26,7 +26,7 @@ import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -156,7 +156,7 @@ public class DummyEnvironment implements Environment { } @Override - public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) { + public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { } @Override http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 4f0242e..7514cc4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -27,7 +27,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -50,8 +50,8 @@ import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.types.Record; import org.apache.flink.util.MutableObjectIterator; - import org.apache.flink.util.Preconditions; + import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -354,7 +354,7 @@ public class MockEnvironment implements Environment { } @Override - public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) { + public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { throw new UnsupportedOperationException(); } http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index c6d2fec..085a386 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -49,7 +50,6 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.util.SerializedValue; @@ -187,7 +187,7 @@ public class TaskAsyncCallTest { Collections.<ResultPartitionDeploymentDescriptor>emptyList(), Collections.<InputGateDeploymentDescriptor>emptyList(), 0, - new TaskStateHandles(), + new TaskStateSnapshot(), mock(MemoryManager.class), mock(IOManager.class), networkEnvironment, @@ -228,7 +228,7 @@ public class TaskAsyncCallTest { } @Override - public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {} + public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception {} @Override public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {