This is an automated email from the ASF dual-hosted git repository. pnowojski pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit db53a04cdec88a9b74e17b028a3a6922e14e45d4 Author: Aleksey Pak <[email protected]> AuthorDate: Thu Jun 20 17:08:52 2019 +0200 [hotfix][tests] Rewrite StreamTaskTest without reflection based fields setting --- .../streaming/runtime/tasks/StreamTaskTest.java | 522 ++++++++++----------- .../flink/streaming/util/StreamTaskUtil.java | 40 ++ 2 files changed, 292 insertions(+), 270 deletions(-) diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index e2171b7..46b5389 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -20,13 +20,10 @@ package org.apache.flink.streaming.runtime.tasks; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; -import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; -import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.core.testutils.OneShotLatch; -import org.apache.flink.mock.Whitebox; import org.apache.flink.runtime.blob.BlobCacheService; import org.apache.flink.runtime.blob.PermanentBlobCache; import org.apache.flink.runtime.blob.TransientBlobCache; @@ -41,6 +38,7 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; +import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; @@ -53,7 +51,6 @@ import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; -import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; @@ -67,7 +64,6 @@ import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.shuffle.ShuffleEnvironment; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; -import org.apache.flink.runtime.state.CheckpointStorage; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.DoneFuture; import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider; @@ -83,7 +79,6 @@ import org.apache.flink.runtime.state.TaskLocalStateStoreImpl; import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.state.TaskStateManagerImpl; import org.apache.flink.runtime.state.TestTaskStateManager; -import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorage; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.taskexecutor.KvStateService; import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; @@ -93,7 +88,6 @@ import org.apache.flink.runtime.taskmanager.NoOpTaskManagerActions; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskExecutionState; import org.apache.flink.runtime.taskmanager.TaskManagerActions; -import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -112,9 +106,12 @@ import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer; import org.apache.flink.util.CloseableIterable; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.TestLogger; +import org.apache.flink.util.function.SupplierWithException; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.invocation.InvocationOnMock; @@ -130,20 +127,19 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.RunnableFuture; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import static org.apache.flink.runtime.concurrent.Executors.newDirectExecutorService; +import static org.apache.flink.streaming.util.StreamTaskUtil.waitTaskIsRunning; import static org.hamcrest.Matchers.everyItem; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; @@ -152,7 +148,6 @@ import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; -import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -174,6 +169,9 @@ public class StreamTaskTest extends TestLogger { private static OneShotLatch syncLatch; + @Rule + public final Timeout timeoutPerTest = Timeout.seconds(30); + /** * This test checks that cancel calls that are issued before the operator is * instantiated still lead to proper canceling. @@ -311,63 +309,36 @@ public class StreamTaskTest extends TestLogger { @Test public void testDecliningCheckpointStreamOperator() throws Exception { - final long checkpointId = 42L; - final long timestamp = 1L; - - TaskInfo mockTaskInfo = mock(TaskInfo.class); - when(mockTaskInfo.getTaskNameWithSubtasks()).thenReturn("foobar"); - when(mockTaskInfo.getIndexOfThisSubtask()).thenReturn(0); CheckpointExceptionHandlerTest.DeclineDummyEnvironment declineDummyEnvironment = new CheckpointExceptionHandlerTest.DeclineDummyEnvironment(); - StreamTask<?, ?> streamTask = new EmptyStreamTask(declineDummyEnvironment); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); - - // mock the operators - StreamOperator<?> streamOperator1 = mock(StreamOperator.class); - StreamOperator<?> streamOperator2 = mock(StreamOperator.class); - StreamOperator<?> streamOperator3 = mock(StreamOperator.class); - // mock the returned snapshots OperatorSnapshotFutures operatorSnapshotResult1 = mock(OperatorSnapshotFutures.class); OperatorSnapshotFutures operatorSnapshotResult2 = mock(OperatorSnapshotFutures.class); final Exception testException = new Exception("Test exception"); - when(streamOperator1.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult1); - when(streamOperator2.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult2); - when(streamOperator3.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))).thenThrow(testException); - - OperatorID operatorID1 = new OperatorID(); - OperatorID operatorID2 = new OperatorID(); - OperatorID operatorID3 = new OperatorID(); - when(streamOperator1.getOperatorID()).thenReturn(operatorID1); - when(streamOperator2.getOperatorID()).thenReturn(operatorID2); - when(streamOperator3.getOperatorID()).thenReturn(operatorID3); - - // set up the task + RunningTask<MockStreamTask> task = runTask(() -> new MockStreamTask( + declineDummyEnvironment, + operatorChain( + streamOperatorWithSnapshot(operatorSnapshotResult1), + streamOperatorWithSnapshot(operatorSnapshotResult2), + streamOperatorWithSnapshotException(testException)))); + MockStreamTask streamTask = task.streamTask; - StreamOperator<?>[] streamOperators = {streamOperator1, streamOperator2, streamOperator3}; + waitTaskIsRunning(streamTask, task.invocationFuture); - OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain = mock(OperatorChain.class); - when(operatorChain.getAllOperators()).thenReturn(streamOperators); - - Whitebox.setInternalState(streamTask, "isRunning", true); - Whitebox.setInternalState(streamTask, "lock", new Object()); - Whitebox.setInternalState(streamTask, "operatorChain", operatorChain); - Whitebox.setInternalState(streamTask, "cancelables", new CloseableRegistry()); - Whitebox.setInternalState(streamTask, "configuration", new StreamConfig(new Configuration())); - Whitebox.setInternalState(streamTask, "checkpointStorage", new MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE)); - - CheckpointExceptionHandlerFactory checkpointExceptionHandlerFactory = new CheckpointExceptionHandlerFactory(); - CheckpointExceptionHandler checkpointExceptionHandler = - checkpointExceptionHandlerFactory.createCheckpointExceptionHandler(declineDummyEnvironment); - Whitebox.setInternalState(streamTask, "checkpointExceptionHandler", checkpointExceptionHandler); + streamTask.triggerCheckpoint( + new CheckpointMetaData(42L, 1L), + CheckpointOptions.forCheckpointWithDefaultLocation(), + false); - streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation(), false); assertEquals(testException, declineDummyEnvironment.getLastDeclinedCheckpointCause()); verify(operatorSnapshotResult1).cancel(); verify(operatorSnapshotResult2).cancel(); + + task.streamTask.finishInput(); + task.waitForTaskCompletion(false); } /** @@ -376,17 +347,6 @@ public class StreamTaskTest extends TestLogger { */ @Test public void testFailingAsyncCheckpointRunnable() throws Exception { - final long checkpointId = 42L; - final long timestamp = 1L; - - MockEnvironment mockEnvironment = new MockEnvironmentBuilder().build(); - StreamTask<?, ?> streamTask = spy(new EmptyStreamTask(mockEnvironment)); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); - - // mock the operators - StreamOperator<?> streamOperator1 = mock(StreamOperator.class); - StreamOperator<?> streamOperator2 = mock(StreamOperator.class); - StreamOperator<?> streamOperator3 = mock(StreamOperator.class); // mock the new state operator snapshots OperatorSnapshotFutures operatorSnapshotResult1 = mock(OperatorSnapshotFutures.class); @@ -398,43 +358,41 @@ public class StreamTaskTest extends TestLogger { when(operatorSnapshotResult3.getOperatorStateRawFuture()).thenReturn(failingFuture); - when(streamOperator1.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult1); - when(streamOperator2.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult2); - when(streamOperator3.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult3); - - OperatorID operatorID1 = new OperatorID(); - OperatorID operatorID2 = new OperatorID(); - OperatorID operatorID3 = new OperatorID(); - when(streamOperator1.getOperatorID()).thenReturn(operatorID1); - when(streamOperator2.getOperatorID()).thenReturn(operatorID2); - when(streamOperator3.getOperatorID()).thenReturn(operatorID3); - - StreamOperator<?>[] streamOperators = {streamOperator1, streamOperator2, streamOperator3}; - - OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain = mock(OperatorChain.class); - when(operatorChain.getAllOperators()).thenReturn(streamOperators); - - Whitebox.setInternalState(streamTask, "isRunning", true); - Whitebox.setInternalState(streamTask, "lock", new Object()); - Whitebox.setInternalState(streamTask, "operatorChain", operatorChain); - Whitebox.setInternalState(streamTask, "cancelables", new CloseableRegistry()); - Whitebox.setInternalState(streamTask, "asyncOperationsThreadPool", newDirectExecutorService()); - Whitebox.setInternalState(streamTask, "configuration", new StreamConfig(new Configuration())); - Whitebox.setInternalState(streamTask, "checkpointStorage", new MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE)); - - CheckpointExceptionHandlerFactory checkpointExceptionHandlerFactory = new CheckpointExceptionHandlerFactory(); - CheckpointExceptionHandler checkpointExceptionHandler = - checkpointExceptionHandlerFactory.createCheckpointExceptionHandler(mockEnvironment); - Whitebox.setInternalState(streamTask, "checkpointExceptionHandler", checkpointExceptionHandler); + try (MockEnvironment mockEnvironment = new MockEnvironmentBuilder().build()) { + RunningTask<MockStreamTask> task = runTask(() -> new MockStreamTask( + mockEnvironment, + operatorChain( + streamOperatorWithSnapshot(operatorSnapshotResult1), + streamOperatorWithSnapshot(operatorSnapshotResult2), + streamOperatorWithSnapshot(operatorSnapshotResult3)))); + + MockStreamTask streamTask = task.streamTask; + + waitTaskIsRunning(streamTask, task.invocationFuture); + + mockEnvironment.setExpectedExternalFailureCause(Throwable.class); + streamTask.triggerCheckpoint( + new CheckpointMetaData(42L, 1L), + CheckpointOptions.forCheckpointWithDefaultLocation(), + false); + + // wait for the completion of the async task + ExecutorService executor = streamTask.getAsyncOperationsThreadPool(); + executor.shutdown(); + if (!executor.awaitTermination(10000L, TimeUnit.MILLISECONDS)) { + fail("Executor did not shut down within the given timeout. This indicates that the " + + "checkpointing did not resume."); + } - mockEnvironment.setExpectedExternalFailureCause(Throwable.class); - streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation(), false); + assertTrue(mockEnvironment.getActualExternalFailureCause().isPresent()); - verify(streamTask).handleAsyncException(anyString(), any(Throwable.class)); + verify(operatorSnapshotResult1).cancel(); + verify(operatorSnapshotResult2).cancel(); + verify(operatorSnapshotResult3).cancel(); - verify(operatorSnapshotResult1).cancel(); - verify(operatorSnapshotResult2).cancel(); - verify(operatorSnapshotResult3).cancel(); + streamTask.finishInput(); + task.waitForTaskCompletion(false); + } } /** @@ -447,8 +405,6 @@ public class StreamTaskTest extends TestLogger { */ @Test public void testAsyncCheckpointingConcurrentCloseAfterAcknowledge() throws Exception { - final long checkpointId = 42L; - final long timestamp = 1L; final OneShotLatch acknowledgeCheckpointLatch = new OneShotLatch(); final OneShotLatch completeAcknowledge = new OneShotLatch(); @@ -478,17 +434,6 @@ public class StreamTaskTest extends TestLogger { null, checkpointResponder); - MockEnvironment mockEnvironment = new MockEnvironmentBuilder() - .setTaskName("mock-task") - .setTaskStateManager(taskStateManager) - .build(); - - StreamTask<?, ?> streamTask = new EmptyStreamTask(mockEnvironment); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); - - StreamOperator<?> streamOperator = mock(StreamOperator.class); - when(streamOperator.getOperatorID()).thenReturn(new OperatorID(42, 42)); - KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class); KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class); OperatorStateHandle managedOperatorStateHandle = mock(OperatorStreamStateHandle.class); @@ -500,62 +445,64 @@ public class StreamTaskTest extends TestLogger { DoneFuture.of(SnapshotResult.of(managedOperatorStateHandle)), DoneFuture.of(SnapshotResult.of(rawOperatorStateHandle))); - when(streamOperator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult); - - StreamOperator<?>[] streamOperators = {streamOperator}; + try (MockEnvironment mockEnvironment = new MockEnvironmentBuilder() + .setTaskName("mock-task") + .setTaskStateManager(taskStateManager) + .build()) { - OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain = mock(OperatorChain.class); - when(operatorChain.getAllOperators()).thenReturn(streamOperators); + RunningTask<MockStreamTask> task = runTask(() -> new MockStreamTask( + mockEnvironment, + operatorChain(streamOperatorWithSnapshot(operatorSnapshotResult)))); - CheckpointStorage checkpointStorage = new MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE); + MockStreamTask streamTask = task.streamTask; + waitTaskIsRunning(streamTask, task.invocationFuture); - Whitebox.setInternalState(streamTask, "isRunning", true); - Whitebox.setInternalState(streamTask, "lock", new Object()); - Whitebox.setInternalState(streamTask, "operatorChain", operatorChain); - Whitebox.setInternalState(streamTask, "cancelables", new CloseableRegistry()); - Whitebox.setInternalState(streamTask, "asyncOperationsThreadPool", Executors.newFixedThreadPool(1)); - Whitebox.setInternalState(streamTask, "configuration", new StreamConfig(new Configuration())); - Whitebox.setInternalState(streamTask, "checkpointStorage", checkpointStorage); + final long checkpointId = 42L; + streamTask.triggerCheckpoint( + new CheckpointMetaData(checkpointId, 1L), + CheckpointOptions.forCheckpointWithDefaultLocation(), + false); - streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation(), false); + acknowledgeCheckpointLatch.await(); - acknowledgeCheckpointLatch.await(); + ArgumentCaptor<TaskStateSnapshot> subtaskStateCaptor = ArgumentCaptor.forClass(TaskStateSnapshot.class); - ArgumentCaptor<TaskStateSnapshot> subtaskStateCaptor = ArgumentCaptor.forClass(TaskStateSnapshot.class); + // check that the checkpoint has been completed + verify(checkpointResponder).acknowledgeCheckpoint( + any(JobID.class), + any(ExecutionAttemptID.class), + eq(checkpointId), + any(CheckpointMetrics.class), + subtaskStateCaptor.capture()); - // check that the checkpoint has been completed - verify(checkpointResponder).acknowledgeCheckpoint( - any(JobID.class), - any(ExecutionAttemptID.class), - eq(checkpointId), - any(CheckpointMetrics.class), - subtaskStateCaptor.capture()); + TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue(); + OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue(); - TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue(); - OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue(); + // check that the subtask state contains the expected state handles + assertEquals(StateObjectCollection.singleton(managedKeyedStateHandle), subtaskState.getManagedKeyedState()); + assertEquals(StateObjectCollection.singleton(rawKeyedStateHandle), subtaskState.getRawKeyedState()); + assertEquals(StateObjectCollection.singleton(managedOperatorStateHandle), subtaskState.getManagedOperatorState()); + assertEquals(StateObjectCollection.singleton(rawOperatorStateHandle), subtaskState.getRawOperatorState()); - // check that the subtask state contains the expected state handles - assertEquals(StateObjectCollection.singleton(managedKeyedStateHandle), subtaskState.getManagedKeyedState()); - assertEquals(StateObjectCollection.singleton(rawKeyedStateHandle), subtaskState.getRawKeyedState()); - assertEquals(StateObjectCollection.singleton(managedOperatorStateHandle), subtaskState.getManagedOperatorState()); - assertEquals(StateObjectCollection.singleton(rawOperatorStateHandle), subtaskState.getRawOperatorState()); + // check that the state handles have not been discarded + verify(managedKeyedStateHandle, never()).discardState(); + verify(rawKeyedStateHandle, never()).discardState(); + verify(managedOperatorStateHandle, never()).discardState(); + verify(rawOperatorStateHandle, never()).discardState(); - // check that the state handles have not been discarded - verify(managedKeyedStateHandle, never()).discardState(); - verify(rawKeyedStateHandle, never()).discardState(); - verify(managedOperatorStateHandle, never()).discardState(); - verify(rawOperatorStateHandle, never()).discardState(); + streamTask.cancel(); - streamTask.cancel(); + completeAcknowledge.trigger(); - completeAcknowledge.trigger(); + // canceling the stream task after it has acknowledged the checkpoint should not discard + // the state handles + verify(managedKeyedStateHandle, never()).discardState(); + verify(rawKeyedStateHandle, never()).discardState(); + verify(managedOperatorStateHandle, never()).discardState(); + verify(rawOperatorStateHandle, never()).discardState(); - // canceling the stream task after it has acknowledged the checkpoint should not discard - // the state handles - verify(managedKeyedStateHandle, never()).discardState(); - verify(rawKeyedStateHandle, never()).discardState(); - verify(managedOperatorStateHandle, never()).discardState(); - verify(rawOperatorStateHandle, never()).discardState(); + task.waitForTaskCompletion(true); + } } /** @@ -568,14 +515,10 @@ public class StreamTaskTest extends TestLogger { */ @Test public void testAsyncCheckpointingConcurrentCloseBeforeAcknowledge() throws Exception { - final long checkpointId = 42L; - final long timestamp = 1L; final OneShotLatch createSubtask = new OneShotLatch(); final OneShotLatch completeSubtask = new OneShotLatch(); - Environment mockEnvironment = spy(new MockEnvironmentBuilder().build()); - whenNew(OperatorSnapshotFinalizer.class). withAnyArguments(). thenAnswer((Answer<OperatorSnapshotFinalizer>) invocation -> { @@ -586,13 +529,6 @@ public class StreamTaskTest extends TestLogger { } ); - StreamTask<?, ?> streamTask = new EmptyStreamTask(mockEnvironment); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); - - final StreamOperator<?> streamOperator = mock(StreamOperator.class); - final OperatorID operatorID = new OperatorID(); - when(streamOperator.getOperatorID()).thenReturn(operatorID); - KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class); KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class); OperatorStateHandle managedOperatorStateHandle = mock(OperatorStreamStateHandle.class); @@ -604,49 +540,47 @@ public class StreamTaskTest extends TestLogger { DoneFuture.of(SnapshotResult.of(managedOperatorStateHandle)), DoneFuture.of(SnapshotResult.of(rawOperatorStateHandle))); - when(streamOperator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult); + final StreamOperator<?> streamOperator = streamOperatorWithSnapshot(operatorSnapshotResult); - StreamOperator<?>[] streamOperators = {streamOperator}; + try (MockEnvironment mockEnvironment = spy(new MockEnvironmentBuilder().build())) { - OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain = mock(OperatorChain.class); - when(operatorChain.getAllOperators()).thenReturn(streamOperators); + RunningTask<MockStreamTask> task = runTask(() -> new MockStreamTask( + mockEnvironment, + operatorChain(streamOperator))); - CheckpointStorage checkpointStorage = new MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE); + waitTaskIsRunning(task.streamTask, task.invocationFuture); - ExecutorService executor = Executors.newFixedThreadPool(1); + final long checkpointId = 42L; + task.streamTask.triggerCheckpoint( + new CheckpointMetaData(checkpointId, 1L), + CheckpointOptions.forCheckpointWithDefaultLocation(), + false); - Whitebox.setInternalState(streamTask, "isRunning", true); - Whitebox.setInternalState(streamTask, "lock", new Object()); - Whitebox.setInternalState(streamTask, "operatorChain", operatorChain); - Whitebox.setInternalState(streamTask, "cancelables", new CloseableRegistry()); - Whitebox.setInternalState(streamTask, "asyncOperationsThreadPool", executor); - Whitebox.setInternalState(streamTask, "configuration", new StreamConfig(new Configuration())); - Whitebox.setInternalState(streamTask, "checkpointStorage", checkpointStorage); + createSubtask.await(); - streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation(), false); + task.streamTask.cancel(); - createSubtask.await(); + completeSubtask.trigger(); - streamTask.cancel(); + // wait for the completion of the async task + ExecutorService executor = task.streamTask.getAsyncOperationsThreadPool(); + executor.shutdown(); + if (!executor.awaitTermination(10000L, TimeUnit.MILLISECONDS)) { + fail("Executor did not shut down within the given timeout. This indicates that the " + + "checkpointing did not resume."); + } - completeSubtask.trigger(); + // check that the checkpoint has not been acknowledged + verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); - // wait for the completion of the async task - executor.shutdown(); + // check that the state handles have been discarded + verify(managedKeyedStateHandle).discardState(); + verify(rawKeyedStateHandle).discardState(); + verify(managedOperatorStateHandle).discardState(); + verify(rawOperatorStateHandle).discardState(); - if (!executor.awaitTermination(10000L, TimeUnit.MILLISECONDS)) { - fail("Executor did not shut down within the given timeout. This indicates that the " + - "checkpointing did not resume."); + task.waitForTaskCompletion(true); } - - // check that the checkpoint has not been acknowledged - verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); - - // check that the state handles have been discarded - verify(managedKeyedStateHandle).discardState(); - verify(rawKeyedStateHandle).discardState(); - verify(managedOperatorStateHandle).discardState(); - verify(rawOperatorStateHandle).discardState(); } /** @@ -657,10 +591,6 @@ public class StreamTaskTest extends TestLogger { */ @Test public void testEmptySubtaskStateLeadsToStatelessAcknowledgment() throws Exception { - final long checkpointId = 42L; - final long timestamp = 1L; - - Environment mockEnvironment = spy(new MockEnvironmentBuilder().build()); // latch blocks until the async checkpoint thread acknowledges final OneShotLatch checkpointCompletedLatch = new OneShotLatch(); @@ -689,42 +619,32 @@ public class StreamTaskTest extends TestLogger { null, checkpointResponder); - when(mockEnvironment.getTaskStateManager()).thenReturn(taskStateManager); + // mock the operator with empty snapshot result (all state handles are null) + StreamOperator<?> statelessOperator = streamOperatorWithSnapshot(new OperatorSnapshotFutures()); - StreamTask<?, ?> streamTask = new EmptyStreamTask(mockEnvironment); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); - - // mock the operators - StreamOperator<?> statelessOperator = - mock(StreamOperator.class); + try (MockEnvironment mockEnvironment = new MockEnvironmentBuilder() + .setTaskStateManager(taskStateManager) + .build()) { - final OperatorID operatorID = new OperatorID(); - when(statelessOperator.getOperatorID()).thenReturn(operatorID); + RunningTask<MockStreamTask> task = runTask(() -> new MockStreamTask( + mockEnvironment, + operatorChain(statelessOperator))); - // mock the returned empty snapshot result (all state handles are null) - OperatorSnapshotFutures statelessOperatorSnapshotResult = new OperatorSnapshotFutures(); - when(statelessOperator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))) - .thenReturn(statelessOperatorSnapshotResult); + waitTaskIsRunning(task.streamTask, task.invocationFuture); - // set up the task - StreamOperator<?>[] streamOperators = {statelessOperator}; - OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain = mock(OperatorChain.class); - when(operatorChain.getAllOperators()).thenReturn(streamOperators); + task.streamTask.triggerCheckpoint( + new CheckpointMetaData(42L, 1L), + CheckpointOptions.forCheckpointWithDefaultLocation(), + false); - Whitebox.setInternalState(streamTask, "isRunning", true); - Whitebox.setInternalState(streamTask, "lock", new Object()); - Whitebox.setInternalState(streamTask, "operatorChain", operatorChain); - Whitebox.setInternalState(streamTask, "cancelables", new CloseableRegistry()); - Whitebox.setInternalState(streamTask, "configuration", new StreamConfig(new Configuration())); - Whitebox.setInternalState(streamTask, "asyncOperationsThreadPool", Executors.newCachedThreadPool()); - Whitebox.setInternalState(streamTask, "checkpointStorage", new MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE)); + checkpointCompletedLatch.await(30, TimeUnit.SECONDS); - streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation(), false); - checkpointCompletedLatch.await(30, TimeUnit.SECONDS); - streamTask.cancel(); + // ensure that 'null' was acknowledged as subtask state + Assert.assertNull(checkpointResult.get(0)); - // ensure that 'null' was acknowledged as subtask state - Assert.assertNull(checkpointResult.get(0)); + task.streamTask.cancel(); + task.waitForTaskCompletion(true); + } } /** @@ -748,37 +668,21 @@ public class StreamTaskTest extends TestLogger { .setBufferSize(1) .setTaskConfiguration(taskConfiguration) .build()) { - StreamTask<Void, BlockingCloseStreamOperator> streamTask = new NoOpStreamTask<>(mockEnvironment); - final AtomicReference<Throwable> atomicThrowable = new AtomicReference<>(null); - CompletableFuture<Void> invokeFuture = CompletableFuture.runAsync( - () -> { - try { - streamTask.invoke(); - } catch (Exception e) { - atomicThrowable.set(e); - } - }, - TestingUtils.defaultExecutor()); + RunningTask<StreamTask<Void, BlockingCloseStreamOperator>> task = runTask(() -> new NoOpStreamTask<>(mockEnvironment)); BlockingCloseStreamOperator.IN_CLOSE.await(); // check that the StreamTask is not yet in isRunning == false - assertTrue(streamTask.isRunning()); + assertTrue(task.streamTask.isRunning()); // let the operator finish its close operation BlockingCloseStreamOperator.FINISH_CLOSE.trigger(); - // wait until the invoke is complete - invokeFuture.get(); + task.waitForTaskCompletion(false); // now the StreamTask should no longer be running - assertFalse(streamTask.isRunning()); - - // check if an exception occurred - if (atomicThrowable.get() != null) { - throw atomicThrowable.get(); - } + assertFalse(task.streamTask.isRunning()); } } @@ -793,22 +697,11 @@ public class StreamTaskTest extends TestLogger { new MockEnvironmentBuilder() .setUserCodeClassLoader(new TestUserCodeClassLoader()) .build()) { - TimeServiceTask timerServiceTask = new TimeServiceTask(mockEnvironment); + RunningTask<TimeServiceTask> task = runTask(() -> new TimeServiceTask(mockEnvironment)); + task.waitForTaskCompletion(false); - CompletableFuture<Void> invokeFuture = CompletableFuture.runAsync( - () -> { - try { - timerServiceTask.invoke(); - } catch (Exception e) { - throw new CompletionException(e); - } - }, - TestingUtils.defaultExecutor()); - - invokeFuture.get(); - - assertThat(timerServiceTask.getClassLoaders(), hasSize(greaterThanOrEqualTo(1))); - assertThat(timerServiceTask.getClassLoaders(), everyItem(instanceOf(TestUserCodeClassLoader.class))); + assertThat(task.streamTask.getClassLoaders(), hasSize(greaterThanOrEqualTo(1))); + assertThat(task.streamTask.getClassLoaders(), everyItem(instanceOf(TestUserCodeClassLoader.class))); } } @@ -816,6 +709,80 @@ public class StreamTaskTest extends TestLogger { // Test Utilities // ------------------------------------------------------------------------ + private static StreamOperator<?> streamOperatorWithSnapshot(OperatorSnapshotFutures operatorSnapshotResult) throws Exception { + StreamOperator<?> operator = mock(StreamOperator.class); + when(operator.getOperatorID()).thenReturn(new OperatorID()); + + when(operator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))) + .thenReturn(operatorSnapshotResult); + + return operator; + } + + private static StreamOperator<?> streamOperatorWithSnapshotException(Exception exception) throws Exception { + StreamOperator<?> operator = mock(StreamOperator.class); + when(operator.getOperatorID()).thenReturn(new OperatorID()); + + when(operator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class))) + .thenThrow(exception); + + return operator; + } + + private static <T> OperatorChain<T, AbstractStreamOperator<T>> operatorChain(StreamOperator<?>... streamOperators) { + OperatorChain<T, AbstractStreamOperator<T>> operatorChain = mock(OperatorChain.class); + when(operatorChain.getAllOperators()).thenReturn(streamOperators); + return operatorChain; + } + + private static class RunningTask<T extends StreamTask<?, ?>> { + final T streamTask; + final CompletableFuture<Void> invocationFuture; + + RunningTask(T streamTask, CompletableFuture<Void> invocationFuture) { + this.streamTask = streamTask; + this.invocationFuture = invocationFuture; + } + + void waitForTaskCompletion(boolean cancelled) throws Exception { + try { + invocationFuture.get(); + } catch (Exception e) { + if (cancelled) { + assertThat(e.getCause(), is(instanceOf(CancelTaskException.class))); + } else { + throw e; + } + } + assertThat(streamTask.isCanceled(), is(cancelled)); + } + } + + private static <T extends StreamTask<?, ?>> RunningTask<T> runTask(SupplierWithException<T, Exception> taskFactory) throws Exception { + CompletableFuture<T> taskCreationFuture = new CompletableFuture<>(); + CompletableFuture<Void> invocationFuture = CompletableFuture.runAsync( + () -> { + T task; + try { + task = taskFactory.get(); + taskCreationFuture.complete(task); + } catch (Exception e) { + taskCreationFuture.completeExceptionally(e); + return; + } + try { + task.invoke(); + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException(e); + } + }, Executors.newSingleThreadExecutor()); + + // Wait until task is created. + return new RunningTask<>(taskCreationFuture.get(), invocationFuture); + } + /** * Operator that does nothing. * @@ -885,7 +852,6 @@ public class StreamTaskTest extends TestLogger { LibraryCacheManager libCache = mock(LibraryCacheManager.class); when(libCache.getClassLoader(any(JobID.class))).thenReturn(StreamTaskTest.class.getClassLoader()); - ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); ResultPartitionConsumableNotifier consumableNotifier = new NoOpResultPartitionConsumableNotifier(); PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); Executor executor = mock(Executor.class); @@ -1012,18 +978,30 @@ public class StreamTaskTest extends TestLogger { // ------------------------------------------------------------------------ // ------------------------------------------------------------------------ - private static class EmptyStreamTask extends StreamTask<String, AbstractStreamOperator<String>> { + private static class MockStreamTask extends StreamTask<String, AbstractStreamOperator<String>> { - public EmptyStreamTask(Environment env) { + private final OperatorChain<String, AbstractStreamOperator<String>> overrideOperatorChain; + private volatile boolean inputFinished; + + MockStreamTask(Environment env, OperatorChain<String, AbstractStreamOperator<String>> operatorChain) { super(env, null); + this.overrideOperatorChain = operatorChain; } @Override - protected void init() throws Exception {} + protected void init() { + // The StreamTask initializes operatorChain first on it's own in `invoke()` method. + // Later it calls the `init()` method before actual `run()`, so we are overriding the operatorChain + // here for test purposes. + super.operatorChain = this.overrideOperatorChain; + super.headOperator = super.operatorChain.getHeadOperator(); + } @Override - protected void performDefaultAction(ActionContext context) throws Exception { - context.allActionsCompleted(); + protected void performDefaultAction(ActionContext context) { + if (isCanceled() || inputFinished) { + context.allActionsCompleted(); + } } @Override @@ -1031,6 +1009,10 @@ public class StreamTaskTest extends TestLogger { @Override protected void cancelTask() throws Exception {} + + void finishInput() { + this.inputFinished = true; + } } /** diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/StreamTaskUtil.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/StreamTaskUtil.java new file mode 100644 index 0000000..17c61ac --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/StreamTaskUtil.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.util; + +import org.apache.flink.streaming.runtime.tasks.StreamTask; + +import java.util.concurrent.CompletableFuture; + +import static org.junit.Assert.fail; + +/** + * Utils for working with StreamTask. + */ +public class StreamTaskUtil { + + public static void waitTaskIsRunning(StreamTask<?, ?> task, CompletableFuture<Void> taskInvocation) throws InterruptedException { + while (!task.isRunning()) { + if (taskInvocation.isDone()) { + fail("Task has stopped"); + } + Thread.sleep(10L); + } + } +}
