This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit db53a04cdec88a9b74e17b028a3a6922e14e45d4
Author: Aleksey Pak <[email protected]>
AuthorDate: Thu Jun 20 17:08:52 2019 +0200

    [hotfix][tests] Rewrite StreamTaskTest without reflection-based field setting
---
 .../streaming/runtime/tasks/StreamTaskTest.java    | 522 ++++++++++-----------
 .../flink/streaming/util/StreamTaskUtil.java       |  40 ++
 2 files changed, 292 insertions(+), 270 deletions(-)

diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index e2171b7..46b5389 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -20,13 +20,10 @@ package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.CheckpointingOptions;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.testutils.OneShotLatch;
-import org.apache.flink.mock.Whitebox;
 import org.apache.flink.runtime.blob.BlobCacheService;
 import org.apache.flink.runtime.blob.PermanentBlobCache;
 import org.apache.flink.runtime.blob.TransientBlobCache;
@@ -41,6 +38,7 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
@@ -53,7 +51,6 @@ import 
org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import org.apache.flink.runtime.io.network.TaskEventDispatcher;
 import 
org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier;
 import 
org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
@@ -67,7 +64,6 @@ import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
-import org.apache.flink.runtime.state.CheckpointStorage;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
@@ -83,7 +79,6 @@ import org.apache.flink.runtime.state.TaskLocalStateStoreImpl;
 import org.apache.flink.runtime.state.TaskStateManager;
 import org.apache.flink.runtime.state.TaskStateManagerImpl;
 import org.apache.flink.runtime.state.TestTaskStateManager;
-import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorage;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.taskexecutor.KvStateService;
 import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker;
@@ -93,7 +88,6 @@ import 
org.apache.flink.runtime.taskmanager.NoOpTaskManagerActions;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.taskmanager.TaskManagerActions;
-import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
@@ -112,9 +106,12 @@ import 
org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer;
 import org.apache.flink.util.CloseableIterable;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
+import org.apache.flink.util.function.SupplierWithException;
 
 import org.junit.Assert;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.Timeout;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.invocation.InvocationOnMock;
@@ -130,20 +127,19 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CompletionException;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
 
-import static 
org.apache.flink.runtime.concurrent.Executors.newDirectExecutorService;
+import static org.apache.flink.streaming.util.StreamTaskUtil.waitTaskIsRunning;
 import static org.hamcrest.Matchers.everyItem;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
@@ -152,7 +148,6 @@ import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.nullable;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
-import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
@@ -174,6 +169,9 @@ public class StreamTaskTest extends TestLogger {
 
        private static OneShotLatch syncLatch;
 
+       @Rule
+       public final Timeout timeoutPerTest = Timeout.seconds(30);
+
        /**
         * This test checks that cancel calls that are issued before the 
operator is
         * instantiated still lead to proper canceling.
@@ -311,63 +309,36 @@ public class StreamTaskTest extends TestLogger {
 
        @Test
        public void testDecliningCheckpointStreamOperator() throws Exception {
-               final long checkpointId = 42L;
-               final long timestamp = 1L;
-
-               TaskInfo mockTaskInfo = mock(TaskInfo.class);
-               
when(mockTaskInfo.getTaskNameWithSubtasks()).thenReturn("foobar");
-               when(mockTaskInfo.getIndexOfThisSubtask()).thenReturn(0);
                CheckpointExceptionHandlerTest.DeclineDummyEnvironment 
declineDummyEnvironment = new 
CheckpointExceptionHandlerTest.DeclineDummyEnvironment();
 
-               StreamTask<?, ?> streamTask = new 
EmptyStreamTask(declineDummyEnvironment);
-               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, timestamp);
-
-               // mock the operators
-               StreamOperator<?> streamOperator1 = mock(StreamOperator.class);
-               StreamOperator<?> streamOperator2 = mock(StreamOperator.class);
-               StreamOperator<?> streamOperator3 = mock(StreamOperator.class);
-
                // mock the returned snapshots
                OperatorSnapshotFutures operatorSnapshotResult1 = 
mock(OperatorSnapshotFutures.class);
                OperatorSnapshotFutures operatorSnapshotResult2 = 
mock(OperatorSnapshotFutures.class);
 
                final Exception testException = new Exception("Test exception");
 
-               when(streamOperator1.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), 
any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult1);
-               when(streamOperator2.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), 
any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult2);
-               when(streamOperator3.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), 
any(CheckpointStreamFactory.class))).thenThrow(testException);
-
-               OperatorID operatorID1 = new OperatorID();
-               OperatorID operatorID2 = new OperatorID();
-               OperatorID operatorID3 = new OperatorID();
-               when(streamOperator1.getOperatorID()).thenReturn(operatorID1);
-               when(streamOperator2.getOperatorID()).thenReturn(operatorID2);
-               when(streamOperator3.getOperatorID()).thenReturn(operatorID3);
-
-               // set up the task
+               RunningTask<MockStreamTask> task = runTask(() -> new 
MockStreamTask(
+                       declineDummyEnvironment,
+                       operatorChain(
+                               
streamOperatorWithSnapshot(operatorSnapshotResult1),
+                               
streamOperatorWithSnapshot(operatorSnapshotResult2),
+                               
streamOperatorWithSnapshotException(testException))));
+               MockStreamTask streamTask = task.streamTask;
 
-               StreamOperator<?>[] streamOperators = {streamOperator1, 
streamOperator2, streamOperator3};
+               waitTaskIsRunning(streamTask, task.invocationFuture);
 
-               OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain 
= mock(OperatorChain.class);
-               
when(operatorChain.getAllOperators()).thenReturn(streamOperators);
-
-               Whitebox.setInternalState(streamTask, "isRunning", true);
-               Whitebox.setInternalState(streamTask, "lock", new Object());
-               Whitebox.setInternalState(streamTask, "operatorChain", 
operatorChain);
-               Whitebox.setInternalState(streamTask, "cancelables", new 
CloseableRegistry());
-               Whitebox.setInternalState(streamTask, "configuration", new 
StreamConfig(new Configuration()));
-               Whitebox.setInternalState(streamTask, "checkpointStorage", new 
MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE));
-
-               CheckpointExceptionHandlerFactory 
checkpointExceptionHandlerFactory = new CheckpointExceptionHandlerFactory();
-               CheckpointExceptionHandler checkpointExceptionHandler =
-                       
checkpointExceptionHandlerFactory.createCheckpointExceptionHandler(declineDummyEnvironment);
-               Whitebox.setInternalState(streamTask, 
"checkpointExceptionHandler", checkpointExceptionHandler);
+               streamTask.triggerCheckpoint(
+                       new CheckpointMetaData(42L, 1L),
+                       CheckpointOptions.forCheckpointWithDefaultLocation(),
+                       false);
 
-               streamTask.triggerCheckpoint(checkpointMetaData, 
CheckpointOptions.forCheckpointWithDefaultLocation(), false);
                assertEquals(testException, 
declineDummyEnvironment.getLastDeclinedCheckpointCause());
 
                verify(operatorSnapshotResult1).cancel();
                verify(operatorSnapshotResult2).cancel();
+
+               task.streamTask.finishInput();
+               task.waitForTaskCompletion(false);
        }
 
        /**
@@ -376,17 +347,6 @@ public class StreamTaskTest extends TestLogger {
         */
        @Test
        public void testFailingAsyncCheckpointRunnable() throws Exception {
-               final long checkpointId = 42L;
-               final long timestamp = 1L;
-
-               MockEnvironment mockEnvironment = new 
MockEnvironmentBuilder().build();
-               StreamTask<?, ?> streamTask = spy(new 
EmptyStreamTask(mockEnvironment));
-               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, timestamp);
-
-               // mock the operators
-               StreamOperator<?> streamOperator1 = mock(StreamOperator.class);
-               StreamOperator<?> streamOperator2 = mock(StreamOperator.class);
-               StreamOperator<?> streamOperator3 = mock(StreamOperator.class);
 
                // mock the new state operator snapshots
                OperatorSnapshotFutures operatorSnapshotResult1 = 
mock(OperatorSnapshotFutures.class);
@@ -398,43 +358,41 @@ public class StreamTaskTest extends TestLogger {
 
                
when(operatorSnapshotResult3.getOperatorStateRawFuture()).thenReturn(failingFuture);
 
-               when(streamOperator1.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), 
any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult1);
-               when(streamOperator2.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), 
any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult2);
-               when(streamOperator3.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), 
any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult3);
-
-               OperatorID operatorID1 = new OperatorID();
-               OperatorID operatorID2 = new OperatorID();
-               OperatorID operatorID3 = new OperatorID();
-               when(streamOperator1.getOperatorID()).thenReturn(operatorID1);
-               when(streamOperator2.getOperatorID()).thenReturn(operatorID2);
-               when(streamOperator3.getOperatorID()).thenReturn(operatorID3);
-
-               StreamOperator<?>[] streamOperators = {streamOperator1, 
streamOperator2, streamOperator3};
-
-               OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain 
= mock(OperatorChain.class);
-               
when(operatorChain.getAllOperators()).thenReturn(streamOperators);
-
-               Whitebox.setInternalState(streamTask, "isRunning", true);
-               Whitebox.setInternalState(streamTask, "lock", new Object());
-               Whitebox.setInternalState(streamTask, "operatorChain", 
operatorChain);
-               Whitebox.setInternalState(streamTask, "cancelables", new 
CloseableRegistry());
-               Whitebox.setInternalState(streamTask, 
"asyncOperationsThreadPool", newDirectExecutorService());
-               Whitebox.setInternalState(streamTask, "configuration", new 
StreamConfig(new Configuration()));
-               Whitebox.setInternalState(streamTask, "checkpointStorage", new 
MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE));
-
-               CheckpointExceptionHandlerFactory 
checkpointExceptionHandlerFactory = new CheckpointExceptionHandlerFactory();
-               CheckpointExceptionHandler checkpointExceptionHandler =
-                       
checkpointExceptionHandlerFactory.createCheckpointExceptionHandler(mockEnvironment);
-               Whitebox.setInternalState(streamTask, 
"checkpointExceptionHandler", checkpointExceptionHandler);
+               try (MockEnvironment mockEnvironment = new 
MockEnvironmentBuilder().build()) {
+                       RunningTask<MockStreamTask> task = runTask(() -> new 
MockStreamTask(
+                               mockEnvironment,
+                               operatorChain(
+                                       
streamOperatorWithSnapshot(operatorSnapshotResult1),
+                                       
streamOperatorWithSnapshot(operatorSnapshotResult2),
+                                       
streamOperatorWithSnapshot(operatorSnapshotResult3))));
+
+                       MockStreamTask streamTask = task.streamTask;
+
+                       waitTaskIsRunning(streamTask, task.invocationFuture);
+
+                       
mockEnvironment.setExpectedExternalFailureCause(Throwable.class);
+                       streamTask.triggerCheckpoint(
+                               new CheckpointMetaData(42L, 1L),
+                               
CheckpointOptions.forCheckpointWithDefaultLocation(),
+                               false);
+
+                       // wait for the completion of the async task
+                       ExecutorService executor = 
streamTask.getAsyncOperationsThreadPool();
+                       executor.shutdown();
+                       if (!executor.awaitTermination(10000L, 
TimeUnit.MILLISECONDS)) {
+                               fail("Executor did not shut down within the 
given timeout. This indicates that the " +
+                                       "checkpointing did not resume.");
+                       }
 
-               
mockEnvironment.setExpectedExternalFailureCause(Throwable.class);
-               streamTask.triggerCheckpoint(checkpointMetaData, 
CheckpointOptions.forCheckpointWithDefaultLocation(), false);
+                       
assertTrue(mockEnvironment.getActualExternalFailureCause().isPresent());
 
-               verify(streamTask).handleAsyncException(anyString(), 
any(Throwable.class));
+                       verify(operatorSnapshotResult1).cancel();
+                       verify(operatorSnapshotResult2).cancel();
+                       verify(operatorSnapshotResult3).cancel();
 
-               verify(operatorSnapshotResult1).cancel();
-               verify(operatorSnapshotResult2).cancel();
-               verify(operatorSnapshotResult3).cancel();
+                       streamTask.finishInput();
+                       task.waitForTaskCompletion(false);
+               }
        }
 
        /**
@@ -447,8 +405,6 @@ public class StreamTaskTest extends TestLogger {
         */
        @Test
        public void testAsyncCheckpointingConcurrentCloseAfterAcknowledge() 
throws Exception {
-               final long checkpointId = 42L;
-               final long timestamp = 1L;
 
                final OneShotLatch acknowledgeCheckpointLatch = new 
OneShotLatch();
                final OneShotLatch completeAcknowledge = new OneShotLatch();
@@ -478,17 +434,6 @@ public class StreamTaskTest extends TestLogger {
                        null,
                        checkpointResponder);
 
-               MockEnvironment mockEnvironment = new MockEnvironmentBuilder()
-                       .setTaskName("mock-task")
-                       .setTaskStateManager(taskStateManager)
-                       .build();
-
-               StreamTask<?, ?> streamTask = new 
EmptyStreamTask(mockEnvironment);
-               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, timestamp);
-
-               StreamOperator<?> streamOperator = mock(StreamOperator.class);
-               when(streamOperator.getOperatorID()).thenReturn(new 
OperatorID(42, 42));
-
                KeyedStateHandle managedKeyedStateHandle = 
mock(KeyedStateHandle.class);
                KeyedStateHandle rawKeyedStateHandle = 
mock(KeyedStateHandle.class);
                OperatorStateHandle managedOperatorStateHandle = 
mock(OperatorStreamStateHandle.class);
@@ -500,62 +445,64 @@ public class StreamTaskTest extends TestLogger {
                        
DoneFuture.of(SnapshotResult.of(managedOperatorStateHandle)),
                        
DoneFuture.of(SnapshotResult.of(rawOperatorStateHandle)));
 
-               when(streamOperator.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), 
any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult);
-
-               StreamOperator<?>[] streamOperators = {streamOperator};
+               try (MockEnvironment mockEnvironment = new 
MockEnvironmentBuilder()
+                       .setTaskName("mock-task")
+                       .setTaskStateManager(taskStateManager)
+                       .build()) {
 
-               OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain 
= mock(OperatorChain.class);
-               
when(operatorChain.getAllOperators()).thenReturn(streamOperators);
+                       RunningTask<MockStreamTask> task = runTask(() -> new 
MockStreamTask(
+                               mockEnvironment,
+                               
operatorChain(streamOperatorWithSnapshot(operatorSnapshotResult))));
 
-               CheckpointStorage checkpointStorage = new 
MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE);
+                       MockStreamTask streamTask = task.streamTask;
+                       waitTaskIsRunning(streamTask, task.invocationFuture);
 
-               Whitebox.setInternalState(streamTask, "isRunning", true);
-               Whitebox.setInternalState(streamTask, "lock", new Object());
-               Whitebox.setInternalState(streamTask, "operatorChain", 
operatorChain);
-               Whitebox.setInternalState(streamTask, "cancelables", new 
CloseableRegistry());
-               Whitebox.setInternalState(streamTask, 
"asyncOperationsThreadPool", Executors.newFixedThreadPool(1));
-               Whitebox.setInternalState(streamTask, "configuration", new 
StreamConfig(new Configuration()));
-               Whitebox.setInternalState(streamTask, "checkpointStorage", 
checkpointStorage);
+                       final long checkpointId = 42L;
+                       streamTask.triggerCheckpoint(
+                               new CheckpointMetaData(checkpointId, 1L),
+                               
CheckpointOptions.forCheckpointWithDefaultLocation(),
+                               false);
 
-               streamTask.triggerCheckpoint(checkpointMetaData, 
CheckpointOptions.forCheckpointWithDefaultLocation(), false);
+                       acknowledgeCheckpointLatch.await();
 
-               acknowledgeCheckpointLatch.await();
+                       ArgumentCaptor<TaskStateSnapshot> subtaskStateCaptor = 
ArgumentCaptor.forClass(TaskStateSnapshot.class);
 
-               ArgumentCaptor<TaskStateSnapshot> subtaskStateCaptor = 
ArgumentCaptor.forClass(TaskStateSnapshot.class);
+                       // check that the checkpoint has been completed
+                       verify(checkpointResponder).acknowledgeCheckpoint(
+                               any(JobID.class),
+                               any(ExecutionAttemptID.class),
+                               eq(checkpointId),
+                               any(CheckpointMetrics.class),
+                               subtaskStateCaptor.capture());
 
-               // check that the checkpoint has been completed
-               verify(checkpointResponder).acknowledgeCheckpoint(
-                       any(JobID.class),
-                       any(ExecutionAttemptID.class),
-                       eq(checkpointId),
-                       any(CheckpointMetrics.class),
-                       subtaskStateCaptor.capture());
+                       TaskStateSnapshot subtaskStates = 
subtaskStateCaptor.getValue();
+                       OperatorSubtaskState subtaskState = 
subtaskStates.getSubtaskStateMappings().iterator().next().getValue();
 
-               TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue();
-               OperatorSubtaskState subtaskState = 
subtaskStates.getSubtaskStateMappings().iterator().next().getValue();
+                       // check that the subtask state contains the expected 
state handles
+                       
assertEquals(StateObjectCollection.singleton(managedKeyedStateHandle), 
subtaskState.getManagedKeyedState());
+                       
assertEquals(StateObjectCollection.singleton(rawKeyedStateHandle), 
subtaskState.getRawKeyedState());
+                       
assertEquals(StateObjectCollection.singleton(managedOperatorStateHandle), 
subtaskState.getManagedOperatorState());
+                       
assertEquals(StateObjectCollection.singleton(rawOperatorStateHandle), 
subtaskState.getRawOperatorState());
 
-               // check that the subtask state contains the expected state 
handles
-               
assertEquals(StateObjectCollection.singleton(managedKeyedStateHandle), 
subtaskState.getManagedKeyedState());
-               
assertEquals(StateObjectCollection.singleton(rawKeyedStateHandle), 
subtaskState.getRawKeyedState());
-               
assertEquals(StateObjectCollection.singleton(managedOperatorStateHandle), 
subtaskState.getManagedOperatorState());
-               
assertEquals(StateObjectCollection.singleton(rawOperatorStateHandle), 
subtaskState.getRawOperatorState());
+                       // check that the state handles have not been discarded
+                       verify(managedKeyedStateHandle, never()).discardState();
+                       verify(rawKeyedStateHandle, never()).discardState();
+                       verify(managedOperatorStateHandle, 
never()).discardState();
+                       verify(rawOperatorStateHandle, never()).discardState();
 
-               // check that the state handles have not been discarded
-               verify(managedKeyedStateHandle, never()).discardState();
-               verify(rawKeyedStateHandle, never()).discardState();
-               verify(managedOperatorStateHandle, never()).discardState();
-               verify(rawOperatorStateHandle, never()).discardState();
+                       streamTask.cancel();
 
-               streamTask.cancel();
+                       completeAcknowledge.trigger();
 
-               completeAcknowledge.trigger();
+                       // canceling the stream task after it has acknowledged 
the checkpoint should not discard
+                       // the state handles
+                       verify(managedKeyedStateHandle, never()).discardState();
+                       verify(rawKeyedStateHandle, never()).discardState();
+                       verify(managedOperatorStateHandle, 
never()).discardState();
+                       verify(rawOperatorStateHandle, never()).discardState();
 
-               // canceling the stream task after it has acknowledged the 
checkpoint should not discard
-               // the state handles
-               verify(managedKeyedStateHandle, never()).discardState();
-               verify(rawKeyedStateHandle, never()).discardState();
-               verify(managedOperatorStateHandle, never()).discardState();
-               verify(rawOperatorStateHandle, never()).discardState();
+                       task.waitForTaskCompletion(true);
+               }
        }
 
        /**
@@ -568,14 +515,10 @@ public class StreamTaskTest extends TestLogger {
         */
        @Test
        public void testAsyncCheckpointingConcurrentCloseBeforeAcknowledge() 
throws Exception {
-               final long checkpointId = 42L;
-               final long timestamp = 1L;
 
                final OneShotLatch createSubtask = new OneShotLatch();
                final OneShotLatch completeSubtask = new OneShotLatch();
 
-               Environment mockEnvironment = spy(new 
MockEnvironmentBuilder().build());
-
                whenNew(OperatorSnapshotFinalizer.class).
                        withAnyArguments().
                        thenAnswer((Answer<OperatorSnapshotFinalizer>) 
invocation -> {
@@ -586,13 +529,6 @@ public class StreamTaskTest extends TestLogger {
                                }
                        );
 
-               StreamTask<?, ?> streamTask = new 
EmptyStreamTask(mockEnvironment);
-               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, timestamp);
-
-               final StreamOperator<?> streamOperator = 
mock(StreamOperator.class);
-               final OperatorID operatorID = new OperatorID();
-               when(streamOperator.getOperatorID()).thenReturn(operatorID);
-
                KeyedStateHandle managedKeyedStateHandle = 
mock(KeyedStateHandle.class);
                KeyedStateHandle rawKeyedStateHandle = 
mock(KeyedStateHandle.class);
                OperatorStateHandle managedOperatorStateHandle = 
mock(OperatorStreamStateHandle.class);
@@ -604,49 +540,47 @@ public class StreamTaskTest extends TestLogger {
                        
DoneFuture.of(SnapshotResult.of(managedOperatorStateHandle)),
                        
DoneFuture.of(SnapshotResult.of(rawOperatorStateHandle)));
 
-               when(streamOperator.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), 
any(CheckpointStreamFactory.class))).thenReturn(operatorSnapshotResult);
+               final StreamOperator<?> streamOperator = 
streamOperatorWithSnapshot(operatorSnapshotResult);
 
-               StreamOperator<?>[] streamOperators = {streamOperator};
+               try (MockEnvironment mockEnvironment = spy(new 
MockEnvironmentBuilder().build())) {
 
-               OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain 
= mock(OperatorChain.class);
-               
when(operatorChain.getAllOperators()).thenReturn(streamOperators);
+                       RunningTask<MockStreamTask> task = runTask(() -> new 
MockStreamTask(
+                               mockEnvironment,
+                               operatorChain(streamOperator)));
 
-               CheckpointStorage checkpointStorage = new 
MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE);
+                       waitTaskIsRunning(task.streamTask, 
task.invocationFuture);
 
-               ExecutorService executor = Executors.newFixedThreadPool(1);
+                       final long checkpointId = 42L;
+                       task.streamTask.triggerCheckpoint(
+                               new CheckpointMetaData(checkpointId, 1L),
+                               
CheckpointOptions.forCheckpointWithDefaultLocation(),
+                               false);
 
-               Whitebox.setInternalState(streamTask, "isRunning", true);
-               Whitebox.setInternalState(streamTask, "lock", new Object());
-               Whitebox.setInternalState(streamTask, "operatorChain", 
operatorChain);
-               Whitebox.setInternalState(streamTask, "cancelables", new 
CloseableRegistry());
-               Whitebox.setInternalState(streamTask, 
"asyncOperationsThreadPool", executor);
-               Whitebox.setInternalState(streamTask, "configuration", new 
StreamConfig(new Configuration()));
-               Whitebox.setInternalState(streamTask, "checkpointStorage", 
checkpointStorage);
+                       createSubtask.await();
 
-               streamTask.triggerCheckpoint(checkpointMetaData, 
CheckpointOptions.forCheckpointWithDefaultLocation(), false);
+                       task.streamTask.cancel();
 
-               createSubtask.await();
+                       completeSubtask.trigger();
 
-               streamTask.cancel();
+                       // wait for the completion of the async task
+                       ExecutorService executor = 
task.streamTask.getAsyncOperationsThreadPool();
+                       executor.shutdown();
+                       if (!executor.awaitTermination(10000L, 
TimeUnit.MILLISECONDS)) {
+                               fail("Executor did not shut down within the 
given timeout. This indicates that the " +
+                                       "checkpointing did not resume.");
+                       }
 
-               completeSubtask.trigger();
+                       // check that the checkpoint has not been acknowledged
+                       verify(mockEnvironment, 
never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), 
any(TaskStateSnapshot.class));
 
-               // wait for the completion of the async task
-               executor.shutdown();
+                       // check that the state handles have been discarded
+                       verify(managedKeyedStateHandle).discardState();
+                       verify(rawKeyedStateHandle).discardState();
+                       verify(managedOperatorStateHandle).discardState();
+                       verify(rawOperatorStateHandle).discardState();
 
-               if (!executor.awaitTermination(10000L, TimeUnit.MILLISECONDS)) {
-                       fail("Executor did not shut down within the given 
timeout. This indicates that the " +
-                               "checkpointing did not resume.");
+                       task.waitForTaskCompletion(true);
                }
-
-               // check that the checkpoint has not been acknowledged
-               verify(mockEnvironment, 
never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), 
any(TaskStateSnapshot.class));
-
-               // check that the state handles have been discarded
-               verify(managedKeyedStateHandle).discardState();
-               verify(rawKeyedStateHandle).discardState();
-               verify(managedOperatorStateHandle).discardState();
-               verify(rawOperatorStateHandle).discardState();
        }
 
        /**
@@ -657,10 +591,6 @@ public class StreamTaskTest extends TestLogger {
         */
        @Test
        public void testEmptySubtaskStateLeadsToStatelessAcknowledgment() 
throws Exception {
-               final long checkpointId = 42L;
-               final long timestamp = 1L;
-
-               Environment mockEnvironment = spy(new 
MockEnvironmentBuilder().build());
 
                // latch blocks until the async checkpoint thread acknowledges
                final OneShotLatch checkpointCompletedLatch = new 
OneShotLatch();
@@ -689,42 +619,32 @@ public class StreamTaskTest extends TestLogger {
                        null,
                        checkpointResponder);
 
-               
when(mockEnvironment.getTaskStateManager()).thenReturn(taskStateManager);
+               // mock the operator with empty snapshot result (all state 
handles are null)
+               StreamOperator<?> statelessOperator = 
streamOperatorWithSnapshot(new OperatorSnapshotFutures());
 
-               StreamTask<?, ?> streamTask = new 
EmptyStreamTask(mockEnvironment);
-               CheckpointMetaData checkpointMetaData = new 
CheckpointMetaData(checkpointId, timestamp);
-
-               // mock the operators
-               StreamOperator<?> statelessOperator =
-                               mock(StreamOperator.class);
+               try (MockEnvironment mockEnvironment = new 
MockEnvironmentBuilder()
+                       .setTaskStateManager(taskStateManager)
+                       .build()) {
 
-               final OperatorID operatorID = new OperatorID();
-               when(statelessOperator.getOperatorID()).thenReturn(operatorID);
+                       RunningTask<MockStreamTask> task = runTask(() -> new 
MockStreamTask(
+                               mockEnvironment,
+                               operatorChain(statelessOperator)));
 
-               // mock the returned empty snapshot result (all state handles 
are null)
-               OperatorSnapshotFutures statelessOperatorSnapshotResult = new 
OperatorSnapshotFutures();
-               when(statelessOperator.snapshotState(anyLong(), anyLong(), 
any(CheckpointOptions.class), any(CheckpointStreamFactory.class)))
-                               .thenReturn(statelessOperatorSnapshotResult);
+                       waitTaskIsRunning(task.streamTask, 
task.invocationFuture);
 
-               // set up the task
-               StreamOperator<?>[] streamOperators = {statelessOperator};
-               OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain 
= mock(OperatorChain.class);
-               
when(operatorChain.getAllOperators()).thenReturn(streamOperators);
+                       task.streamTask.triggerCheckpoint(
+                               new CheckpointMetaData(42L, 1L),
+                               
CheckpointOptions.forCheckpointWithDefaultLocation(),
+                               false);
 
-               Whitebox.setInternalState(streamTask, "isRunning", true);
-               Whitebox.setInternalState(streamTask, "lock", new Object());
-               Whitebox.setInternalState(streamTask, "operatorChain", 
operatorChain);
-               Whitebox.setInternalState(streamTask, "cancelables", new 
CloseableRegistry());
-               Whitebox.setInternalState(streamTask, "configuration", new 
StreamConfig(new Configuration()));
-               Whitebox.setInternalState(streamTask, 
"asyncOperationsThreadPool", Executors.newCachedThreadPool());
-               Whitebox.setInternalState(streamTask, "checkpointStorage", new 
MemoryBackendCheckpointStorage(new JobID(), null, null, Integer.MAX_VALUE));
+                       checkpointCompletedLatch.await(30, TimeUnit.SECONDS);
 
-               streamTask.triggerCheckpoint(checkpointMetaData, 
CheckpointOptions.forCheckpointWithDefaultLocation(), false);
-               checkpointCompletedLatch.await(30, TimeUnit.SECONDS);
-               streamTask.cancel();
+                       // ensure that 'null' was acknowledged as subtask state
+                       Assert.assertNull(checkpointResult.get(0));
 
-               // ensure that 'null' was acknowledged as subtask state
-               Assert.assertNull(checkpointResult.get(0));
+                       task.streamTask.cancel();
+                       task.waitForTaskCompletion(true);
+               }
        }
 
        /**
@@ -748,37 +668,21 @@ public class StreamTaskTest extends TestLogger {
                                        .setBufferSize(1)
                                        .setTaskConfiguration(taskConfiguration)
                                        .build()) {
-                       StreamTask<Void, BlockingCloseStreamOperator> 
streamTask = new NoOpStreamTask<>(mockEnvironment);
-                       final AtomicReference<Throwable> atomicThrowable = new 
AtomicReference<>(null);
 
-                       CompletableFuture<Void> invokeFuture = 
CompletableFuture.runAsync(
-                               () -> {
-                                       try {
-                                               streamTask.invoke();
-                                       } catch (Exception e) {
-                                               atomicThrowable.set(e);
-                                       }
-                               },
-                               TestingUtils.defaultExecutor());
+                       RunningTask<StreamTask<Void, 
BlockingCloseStreamOperator>> task = runTask(() -> new 
NoOpStreamTask<>(mockEnvironment));
 
                        BlockingCloseStreamOperator.IN_CLOSE.await();
 
                        // check that the StreamTask is not yet in isRunning == 
false
-                       assertTrue(streamTask.isRunning());
+                       assertTrue(task.streamTask.isRunning());
 
                        // let the operator finish its close operation
                        BlockingCloseStreamOperator.FINISH_CLOSE.trigger();
 
-                       // wait until the invoke is complete
-                       invokeFuture.get();
+                       task.waitForTaskCompletion(false);
 
                        // now the StreamTask should no longer be running
-                       assertFalse(streamTask.isRunning());
-
-                       // check if an exception occurred
-                       if (atomicThrowable.get() != null) {
-                               throw atomicThrowable.get();
-                       }
+                       assertFalse(task.streamTask.isRunning());
                }
        }
 
@@ -793,22 +697,11 @@ public class StreamTaskTest extends TestLogger {
                        new MockEnvironmentBuilder()
                                .setUserCodeClassLoader(new 
TestUserCodeClassLoader())
                                .build()) {
-                       TimeServiceTask timerServiceTask = new 
TimeServiceTask(mockEnvironment);
+                       RunningTask<TimeServiceTask> task = runTask(() -> new 
TimeServiceTask(mockEnvironment));
+                       task.waitForTaskCompletion(false);
 
-                       CompletableFuture<Void> invokeFuture = 
CompletableFuture.runAsync(
-                               () -> {
-                                       try {
-                                               timerServiceTask.invoke();
-                                       } catch (Exception e) {
-                                               throw new 
CompletionException(e);
-                                       }
-                               },
-                               TestingUtils.defaultExecutor());
-
-                       invokeFuture.get();
-
-                       assertThat(timerServiceTask.getClassLoaders(), 
hasSize(greaterThanOrEqualTo(1)));
-                       assertThat(timerServiceTask.getClassLoaders(), 
everyItem(instanceOf(TestUserCodeClassLoader.class)));
+                       assertThat(task.streamTask.getClassLoaders(), 
hasSize(greaterThanOrEqualTo(1)));
+                       assertThat(task.streamTask.getClassLoaders(), 
everyItem(instanceOf(TestUserCodeClassLoader.class)));
                }
        }
 
@@ -816,6 +709,80 @@ public class StreamTaskTest extends TestLogger {
        //  Test Utilities
        // 
------------------------------------------------------------------------
 
+       /**
+        * Creates a mocked operator (with its own {@link OperatorID}) whose {@code snapshotState(...)} returns the
+        * given snapshot result for any arguments.
+        */
+       private static StreamOperator<?> streamOperatorWithSnapshot(OperatorSnapshotFutures operatorSnapshotResult) throws Exception {
+               StreamOperator<?> mockedOperator = mockStreamOperatorWithNewId();
+               when(mockedOperator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class)))
+                       .thenReturn(operatorSnapshotResult);
+               return mockedOperator;
+       }
+
+       /**
+        * Creates a mocked operator (with its own {@link OperatorID}) whose {@code snapshotState(...)} throws the
+        * given exception for any arguments.
+        */
+       private static StreamOperator<?> streamOperatorWithSnapshotException(Exception exception) throws Exception {
+               StreamOperator<?> mockedOperator = mockStreamOperatorWithNewId();
+               when(mockedOperator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class), any(CheckpointStreamFactory.class)))
+                       .thenThrow(exception);
+               return mockedOperator;
+       }
+
+       /**
+        * Creates a mocked {@link StreamOperator} whose {@code getOperatorID()} returns a newly created {@link OperatorID}.
+        */
+       private static StreamOperator<?> mockStreamOperatorWithNewId() {
+               StreamOperator<?> mockedOperator = mock(StreamOperator.class);
+               when(mockedOperator.getOperatorID()).thenReturn(new OperatorID());
+               return mockedOperator;
+       }
+
+       /**
+        * Creates a mocked {@link OperatorChain} whose {@code getAllOperators()} returns the given operators.
+        */
+       @SuppressWarnings("unchecked")
+       private static <T> OperatorChain<T, AbstractStreamOperator<T>> operatorChain(StreamOperator<?>... streamOperators) {
+               // mock() can only produce the raw OperatorChain type, hence the (safe) unchecked conversion.
+               OperatorChain<T, AbstractStreamOperator<T>> operatorChain = mock(OperatorChain.class);
+               when(operatorChain.getAllOperators()).thenReturn(streamOperators);
+               return operatorChain;
+       }
+
+       /**
+        * A stream task together with the future of its asynchronously running {@code invoke()} call.
+        */
+       private static class RunningTask<T extends StreamTask<?, ?>> {
+               final T streamTask;
+               final CompletableFuture<Void> invocationFuture;
+
+               RunningTask(T streamTask, CompletableFuture<Void> invocationFuture) {
+                       this.streamTask = streamTask;
+                       this.invocationFuture = invocationFuture;
+               }
+
+               /**
+                * Blocks until the task invocation terminates and checks whether the task ended up cancelled.
+                *
+                * @param cancelled whether the task is expected to be cancelled; in that case the invocation may fail
+                *                  with a {@link CancelTaskException}, any other failure is rethrown
+                * @throws Exception the invocation's failure, unless it is the expected cancellation
+                */
+               void waitForTaskCompletion(boolean cancelled) throws Exception {
+                       try {
+                               invocationFuture.get();
+                       } catch (Exception e) {
+                               // Only a CancelTaskException of a cancelled task is expected. Rethrow everything else
+                               // so the original stack trace is reported instead of a bare type mismatch.
+                               if (!cancelled || !(e.getCause() instanceof CancelTaskException)) {
+                                       throw e;
+                               }
+                       }
+                       assertThat(streamTask.isCanceled(), is(cancelled));
+               }
+       }
+
+       /**
+        * Creates a stream task via {@code taskFactory} and invokes it on a dedicated background thread.
+        *
+        * <p>The factory is called on that background thread as well. This method only blocks until the task
+        * object has been created, not until {@code invoke()} returns. Checked exceptions thrown by
+        * {@code invoke()} are wrapped in a {@link RuntimeException} and thus complete the returned
+        * invocation future exceptionally.
+        *
+        * @param taskFactory creates the task under test
+        * @return the created task together with the future of its {@code invoke()} call
+        * @throws Exception if the task could not be created (reported through {@code CompletableFuture#get()})
+        */
+       private static <T extends StreamTask<?, ?>> RunningTask<T> runTask(SupplierWithException<T, Exception> taskFactory) throws Exception {
+               java.util.concurrent.ExecutorService taskExecutor = Executors.newSingleThreadExecutor();
+               CompletableFuture<T> taskCreationFuture = new CompletableFuture<>();
+               CompletableFuture<Void> invocationFuture = CompletableFuture.runAsync(
+                       () -> {
+                               T task;
+                               try {
+                                       task = taskFactory.get();
+                                       taskCreationFuture.complete(task);
+                               } catch (Exception e) {
+                                       taskCreationFuture.completeExceptionally(e);
+                                       return;
+                               }
+                               try {
+                                       task.invoke();
+                               } catch (RuntimeException e) {
+                                       throw e;
+                               } catch (Exception e) {
+                                       throw new RuntimeException(e);
+                               }
+                       }, taskExecutor);
+               // The invocation has already been submitted: shutdown() still lets it run to completion, but then
+               // releases the worker thread instead of leaking one non-daemon thread per test.
+               taskExecutor.shutdown();
+
+               // Wait until task is created.
+               return new RunningTask<>(taskCreationFuture.get(), invocationFuture);
+       }
+
        /**
         * Operator that does nothing.
         *
@@ -885,7 +852,6 @@ public class StreamTaskTest extends TestLogger {
                LibraryCacheManager libCache = mock(LibraryCacheManager.class);
                
when(libCache.getClassLoader(any(JobID.class))).thenReturn(StreamTaskTest.class.getClassLoader());
 
-               ResultPartitionManager partitionManager = 
mock(ResultPartitionManager.class);
                ResultPartitionConsumableNotifier consumableNotifier = new 
NoOpResultPartitionConsumableNotifier();
                PartitionProducerStateChecker partitionProducerStateChecker = 
mock(PartitionProducerStateChecker.class);
                Executor executor = mock(Executor.class);
@@ -1012,18 +978,30 @@ public class StreamTaskTest extends TestLogger {
        // 
------------------------------------------------------------------------
        // 
------------------------------------------------------------------------
 
-       private static class EmptyStreamTask extends StreamTask<String, 
AbstractStreamOperator<String>> {
+       private static class MockStreamTask extends StreamTask<String, 
AbstractStreamOperator<String>> {
 
-               public EmptyStreamTask(Environment env) {
+               private final OperatorChain<String, 
AbstractStreamOperator<String>> overrideOperatorChain;
+               private volatile boolean inputFinished;
+
+               MockStreamTask(Environment env, OperatorChain<String, 
AbstractStreamOperator<String>> operatorChain) {
                        super(env, null);
+                       this.overrideOperatorChain = operatorChain;
                }
 
                @Override
-               protected void init() throws Exception {}
+               protected void init() {
+                       // StreamTask#invoke() first creates its own operatorChain and only afterwards calls init()
+                       // (before the actual run()), so this is where the chain passed to the constructor replaces it
+                       // for test purposes. The head operator is taken from that replacement chain as well.
+                       super.operatorChain = this.overrideOperatorChain;
+                       super.headOperator = 
super.operatorChain.getHeadOperator();
+               }
 
                @Override
-               protected void performDefaultAction(ActionContext context) 
throws Exception {
-                       context.allActionsCompleted();
+               protected void performDefaultAction(ActionContext context) {
+                       // Report completion only once the task was cancelled or finishInput() was called; until then this
+                       // action does nothing, which presumably keeps the task running so tests can trigger checkpoints
+                       // while isRunning() holds.
+                       if (isCanceled() || inputFinished) {
+                               context.allActionsCompleted();
+                       }
                }
 
                @Override
@@ -1031,6 +1009,10 @@ public class StreamTaskTest extends TestLogger {
 
                @Override
                protected void cancelTask() throws Exception {}
+
+               /**
+                * Marks the input as finished, so that subsequent {@code performDefaultAction} calls report all
+                * actions as completed.
+                */
+               void finishInput() {
+                       this.inputFinished = true;
+               }
        }
 
        /**
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/StreamTaskUtil.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/StreamTaskUtil.java
new file mode 100644
index 0000000..17c61ac
--- /dev/null
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/StreamTaskUtil.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.util;
+
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+
+import static org.junit.Assert.fail;
+
+/**
+ * Utils for working with {@link StreamTask} instances in tests.
+ */
+public class StreamTaskUtil {
+
+       private StreamTaskUtil() {
+               // static utility class, not meant to be instantiated
+       }
+
+       /**
+        * Blocks until the given task reports {@link StreamTask#isRunning()}, polling every 10 ms.
+        *
+        * <p>If the task invocation has already terminated while the task is not running, the calling test fails.
+        * When the invocation terminated exceptionally or was cancelled, that exception is attached as the cause of
+        * the failure so the reason the task stopped is not lost.
+        *
+        * <p>Note: there is no timeout. If the task neither starts running nor terminates, this method keeps
+        * polling until the calling thread is interrupted.
+        *
+        * @param task the task to wait for
+        * @param taskInvocation the future of the task's invocation
+        * @throws InterruptedException if the waiting thread is interrupted while sleeping
+        */
+       public static void waitTaskIsRunning(StreamTask<?, ?> task, CompletableFuture<Void> taskInvocation) throws InterruptedException {
+               while (!task.isRunning()) {
+                       if (taskInvocation.isDone()) {
+                               failBecauseInvocationTerminated(taskInvocation);
+                       }
+                       Thread.sleep(10L);
+               }
+       }
+
+       /**
+        * Fails the calling test because the (already completed) task invocation terminated, attaching the
+        * invocation's exception, if any, as the cause.
+        */
+       private static void failBecauseInvocationTerminated(CompletableFuture<Void> taskInvocation) {
+               try {
+                       taskInvocation.getNow(null);
+               } catch (CompletionException e) {
+                       throw new AssertionError("Task has stopped", e.getCause());
+               } catch (CancellationException e) {
+                       throw new AssertionError("Task has stopped", e);
+               }
+               fail("Task has stopped");
+       }
+}

Reply via email to