Repository: flink
Updated Branches:
  refs/heads/release-1.4 005a87177 -> 2117eb77b


[FLINK-8005] Set user-code class loader as context loader before snapshot

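During checkpointing, user code may dynamically load classes from the user-code
jar. This is a problem if the thread invoking the snapshot callbacks does not
have the user-code class loader set as its context class loader. This commit
makes the Task's async-call dispatcher thread (which invokes the checkpoint
trigger, checkpoint-complete notification, and stop callbacks) use the
user-code class loader as its context class loader. The former TaskStopTest
cases are folded into TaskAsyncCallTest, which now also asserts that the
correct class loader is set.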


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/2117eb77
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/2117eb77
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/2117eb77

Branch: refs/heads/release-1.4
Commit: 2117eb77bb9d34da4288b5dd4455ef06c583ce7c
Parents: 005a871
Author: gyao <g...@data-artisans.com>
Authored: Wed Nov 8 11:46:45 2017 +0100
Committer: Aljoscha Krettek <aljoscha.kret...@gmail.com>
Committed: Fri Nov 10 09:26:37 2017 +0100

----------------------------------------------------------------------
 .../taskmanager/DispatcherThreadFactory.java    |  24 ++-
 .../apache/flink/runtime/taskmanager/Task.java  |  22 +-
 .../runtime/taskmanager/TaskAsyncCallTest.java  | 206 +++++++++++++++----
 .../flink/runtime/taskmanager/TaskStopTest.java | 157 --------------
 4 files changed, 204 insertions(+), 205 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/2117eb77/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java
index 97060a8..543b159 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.runtime.taskmanager;
 
+import javax.annotation.Nullable;
+
 import java.util.concurrent.ThreadFactory;
 
 /**
@@ -29,21 +31,41 @@ public class DispatcherThreadFactory implements 
ThreadFactory {
        private final ThreadGroup group;
        
        private final String threadName;
+
+       private final ClassLoader classLoader;
        
        /**
         * Creates a new thread factory.
-        * 
+        *
         * @param group The group that the threads will be associated with.
         * @param threadName The name for the threads.
         */
        public DispatcherThreadFactory(ThreadGroup group, String threadName) {
+               this(group, threadName, null);
+       }
+
+       /**
+        * Creates a new thread factory.
+        *
+        * @param group The group that the threads will be associated with.
+        * @param threadName The name for the threads.
+        * @param classLoader The {@link ClassLoader} to be set as context 
class loader.
+        */
+       public DispatcherThreadFactory(
+                       ThreadGroup group,
+                       String threadName,
+                       @Nullable ClassLoader classLoader) {
                this.group = group;
                this.threadName = threadName;
+               this.classLoader = classLoader;
        }
 
        @Override
        public Thread newThread(Runnable r) {
                Thread t = new Thread(group, r, threadName);
+               if (classLoader != null) {
+                       t.setContextClassLoader(classLoader);
+               }
                t.setDaemon(true);
                return t;
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/2117eb77/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 58dd9e3..2cb356c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -99,6 +99,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /**
  * The Task represents one execution of a parallel subtask on a TaskManager.
@@ -265,6 +266,12 @@ public class Task implements Runnable, TaskActions {
        private long taskCancellationTimeout;
 
        /**
+        * This class loader should be set as the context class loader of the 
threads in
+        * {@link #asyncCallDispatcher} because user code may dynamically load 
classes in all callbacks.
+        */
+       private ClassLoader userCodeClassLoader;
+
+       /**
         * <p><b>IMPORTANT:</b> This constructor may not start any work that 
would need to
         * be undone in the case of a failing task deployment.</p>
         */
@@ -563,7 +570,6 @@ public class Task implements Runnable, TaskActions {
                Map<String, Future<Path>> distributedCacheEntries = new 
HashMap<String, Future<Path>>();
                AbstractInvokable invokable = null;
 
-               ClassLoader userCodeClassLoader;
                try {
                        // ----------------------------
                        //  Task Bootstrap - We periodically
@@ -580,7 +586,7 @@ public class Task implements Runnable, TaskActions {
                        // this may involve downloading the job's JAR files 
and/or classes
                        LOG.info("Loading JAR files for task {}.", this);
 
-                       userCodeClassLoader = 
createUserCodeClassloader(libraryCache);
+                       userCodeClassLoader = createUserCodeClassloader();
                        final ExecutionConfig executionConfig = 
serializedExecutionConfig.deserializeValue(userCodeClassLoader);
 
                        if (executionConfig.getTaskCancellationInterval() >= 0) 
{
@@ -865,7 +871,7 @@ public class Task implements Runnable, TaskActions {
                }
        }
 
-       private ClassLoader createUserCodeClassloader(LibraryCacheManager 
libraryCache) throws Exception {
+       private ClassLoader createUserCodeClassloader() throws Exception {
                long startDownloadTime = System.currentTimeMillis();
 
                // triggers the download of all missing jar files from the job 
manager
@@ -1342,15 +1348,19 @@ public class Task implements Runnable, TaskActions {
                        if (executionState != ExecutionState.RUNNING) {
                                return;
                        }
-                       
+
                        // get ourselves a reference on the stack that cannot 
be concurrently modified
                        ExecutorService executor = this.asyncCallDispatcher;
                        if (executor == null) {
                                // first time use, initialize
+                               checkState(userCodeClassLoader != null, 
"userCodeClassLoader must not be null");
                                executor = Executors.newSingleThreadExecutor(
-                                               new 
DispatcherThreadFactory(TASK_THREADS_GROUP, "Async calls on " + 
taskNameWithSubtask));
+                                               new DispatcherThreadFactory(
+                                                       TASK_THREADS_GROUP,
+                                                       "Async calls on " + 
taskNameWithSubtask,
+                                                       userCodeClassLoader));
                                this.asyncCallDispatcher = executor;
-                               
+
                                // double-check for execution state, and make 
sure we clean up after ourselves
                                // if we created the dispatcher while the task 
was concurrently canceled
                                if (executionState != ExecutionState.RUNNING) {

http://git-wip-us.apache.org/repos/asf/flink/blob/2117eb77/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index d925e4d..5045606 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -48,6 +48,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
+import org.apache.flink.runtime.jobgraph.tasks.StoppableTask;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
@@ -58,10 +59,19 @@ import org.apache.flink.util.SerializedValue;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.util.ArrayList;
 import java.util.Collections;
+import java.util.List;
 import java.util.concurrent.Executor;
 
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.everyItem;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.isOneOf;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
@@ -69,61 +79,76 @@ import static org.mockito.Mockito.when;
 
 public class TaskAsyncCallTest {
 
-       private static final int NUM_CALLS = 1000;
-       
+       /** Number of expected checkpoints. */
+       private static int numCalls;
+
+       /** Triggered at the beginning of {@link 
CheckpointsInOrderInvokable#invoke()}. */
        private static OneShotLatch awaitLatch;
+
+       /**
+        * Triggered when {@link 
CheckpointsInOrderInvokable#triggerCheckpoint(CheckpointMetaData, 
CheckpointOptions)}
+        * was called {@link #numCalls} times.
+        */
        private static OneShotLatch triggerLatch;
 
+       /**
+        * Triggered when {@link 
CheckpointsInOrderInvokable#notifyCheckpointComplete(long)}
+        * was called {@link #numCalls} times.
+        */
+       private static OneShotLatch notifyCheckpointCompleteLatch;
+
+       /** Triggered on {@link 
ContextClassLoaderInterceptingInvokable#stop()}}. */
+       private static OneShotLatch stopLatch;
+
+       private static final List<ClassLoader> classLoaders = 
Collections.synchronizedList(new ArrayList<>());
+
        @Before
        public void createQueuesAndActors() {
+               numCalls = 1000;
+
                awaitLatch = new OneShotLatch();
                triggerLatch = new OneShotLatch();
+               notifyCheckpointCompleteLatch = new OneShotLatch();
+               stopLatch = new OneShotLatch();
+
+               classLoaders.clear();
        }
 
 
        // 
------------------------------------------------------------------------
        //  Tests 
        // 
------------------------------------------------------------------------
-       
+
        @Test
-       public void testCheckpointCallsInOrder() {
-               try {
-                       Task task = createTask();
+       public void testCheckpointCallsInOrder() throws Exception {
+               Task task = createTask(CheckpointsInOrderInvokable.class);
+               try (TaskCleaner ignored = new TaskCleaner(task)) {
                        task.startTaskThread();
-                       
+
                        awaitLatch.await();
-                       
-                       for (int i = 1; i <= NUM_CALLS; i++) {
+
+                       for (int i = 1; i <= numCalls; i++) {
                                task.triggerCheckpointBarrier(i, 156865867234L, 
CheckpointOptions.forCheckpoint());
                        }
-                       
+
                        triggerLatch.await();
-                       
+
                        assertFalse(task.isCanceledOrFailed());
 
                        ExecutionState currentState = task.getExecutionState();
-                       if (currentState != ExecutionState.RUNNING && 
currentState != ExecutionState.FINISHED) {
-                               fail("Task should be RUNNING or FINISHED, but 
is " + currentState);
-                       }
-                       
-                       task.cancelExecution();
-                       task.getExecutingThread().join();
-               }
-               catch (Exception e) {
-                       e.printStackTrace();
-                       fail(e.getMessage());
+                       assertThat(currentState, 
isOneOf(ExecutionState.RUNNING, ExecutionState.FINISHED));
                }
        }
 
        @Test
-       public void testMixedAsyncCallsInOrder() {
-               try {
-                       Task task = createTask();
+       public void testMixedAsyncCallsInOrder() throws Exception {
+               Task task = createTask(CheckpointsInOrderInvokable.class);
+               try (TaskCleaner ignored = new TaskCleaner(task)) {
                        task.startTaskThread();
 
                        awaitLatch.await();
 
-                       for (int i = 1; i <= NUM_CALLS; i++) {
+                       for (int i = 1; i <= numCalls; i++) {
                                task.triggerCheckpointBarrier(i, 156865867234L, 
CheckpointOptions.forCheckpoint());
                                task.notifyCheckpointComplete(i);
                        }
@@ -131,26 +156,62 @@ public class TaskAsyncCallTest {
                        triggerLatch.await();
 
                        assertFalse(task.isCanceledOrFailed());
+
                        ExecutionState currentState = task.getExecutionState();
-                       if (currentState != ExecutionState.RUNNING && 
currentState != ExecutionState.FINISHED) {
-                               fail("Task should be RUNNING or FINISHED, but 
is " + currentState);
-                       }
+                       assertThat(currentState, 
isOneOf(ExecutionState.RUNNING, ExecutionState.FINISHED));
+               }
+       }
 
-                       task.cancelExecution();
-                       task.getExecutingThread().join();
+       @Test
+       public void testThrowExceptionIfStopInvokedWithNotStoppableTask() 
throws Exception {
+               Task task = createTask(CheckpointsInOrderInvokable.class);
+               try (TaskCleaner ignored = new TaskCleaner(task)) {
+                       task.startTaskThread();
+                       awaitLatch.await();
+
+                       try {
+                               task.stopExecution();
+                               fail("Expected exception not thrown");
+                       } catch (UnsupportedOperationException e) {
+                               assertThat(e.getMessage(), 
containsString("Stopping not supported by task"));
+                       }
                }
-               catch (Exception e) {
-                       e.printStackTrace();
-                       fail(e.getMessage());
+       }
+
+       /**
+        * Asserts that {@link 
StatefulTask#triggerCheckpoint(CheckpointMetaData, CheckpointOptions)},
+        * {@link StatefulTask#notifyCheckpointComplete(long)}, and {@link 
StoppableTask#stop()} are
+        * invoked by a thread whose context class loader is set to the user 
code class loader.
+        */
+       @Test
+       public void testSetsUserCodeClassLoader() throws Exception {
+               numCalls = 1;
+
+               Task task = 
createTask(ContextClassLoaderInterceptingInvokable.class);
+               try (TaskCleaner ignored = new TaskCleaner(task)) {
+                       task.startTaskThread();
+
+                       awaitLatch.await();
+
+                       task.triggerCheckpointBarrier(1, 1, 
CheckpointOptions.forCheckpoint());
+                       task.notifyCheckpointComplete(1);
+                       task.stopExecution();
+
+                       triggerLatch.await();
+                       notifyCheckpointCompleteLatch.await();
+                       stopLatch.await();
+
+                       assertThat(classLoaders, 
hasSize(greaterThanOrEqualTo(3)));
+                       assertThat(classLoaders, 
everyItem(instanceOf(TestUserCodeClassLoader.class)));
                }
        }
-       
-       private static Task createTask() throws Exception {
+
+       private Task createTask(Class<? extends AbstractInvokable> 
invokableClass) throws Exception {
                BlobCacheService blobService =
                        new BlobCacheService(mock(PermanentBlobCache.class), 
mock(TransientBlobCache.class));
 
                LibraryCacheManager libCache = mock(LibraryCacheManager.class);
-               
when(libCache.getClassLoader(any(JobID.class))).thenReturn(ClassLoader.getSystemClassLoader());
+               when(libCache.getClassLoader(any(JobID.class))).thenReturn(new 
TestUserCodeClassLoader());
                
                ResultPartitionManager partitionManager = 
mock(ResultPartitionManager.class);
                ResultPartitionConsumableNotifier consumableNotifier = 
mock(ResultPartitionConsumableNotifier.class);
@@ -178,7 +239,7 @@ public class TaskAsyncCallTest {
                        "Test Task",
                        1,
                        1,
-                       CheckpointsInOrderInvokable.class.getName(),
+                       invokableClass.getName(),
                        new Configuration());
 
                return new Task(
@@ -221,13 +282,17 @@ public class TaskAsyncCallTest {
                        
                        // wait forever (until canceled)
                        synchronized (this) {
-                               while (error == null && lastCheckpointId < 
NUM_CALLS) {
+                               while (error == null && lastCheckpointId < 
numCalls) {
                                        wait();
                                }
                        }
-                       
-                       triggerLatch.trigger();
+
                        if (error != null) {
+                               // exit method prematurely due to error but 
make sure that the tests can finish
+                               triggerLatch.trigger();
+                               notifyCheckpointCompleteLatch.trigger();
+                               stopLatch.trigger();
+
                                throw error;
                        }
                }
@@ -239,7 +304,7 @@ public class TaskAsyncCallTest {
                public boolean triggerCheckpoint(CheckpointMetaData 
checkpointMetaData, CheckpointOptions checkpointOptions) {
                        lastCheckpointId++;
                        if (checkpointMetaData.getCheckpointId() == 
lastCheckpointId) {
-                               if (lastCheckpointId == NUM_CALLS) {
+                               if (lastCheckpointId == numCalls) {
                                        triggerLatch.trigger();
                                }
                        }
@@ -269,7 +334,66 @@ public class TaskAsyncCallTest {
                                synchronized (this) {
                                        notifyAll();
                                }
+                       } else if (lastCheckpointId == numCalls) {
+                               notifyCheckpointCompleteLatch.trigger();
                        }
                }
        }
+
+       /**
+        * This is an {@link AbstractInvokable} that stores the context class 
loader of the invoking
+        * thread in a static field so that tests can assert on the class 
loader instances.
+        *
+        * @see #testSetsUserCodeClassLoader()
+        */
+       public static class ContextClassLoaderInterceptingInvokable extends 
CheckpointsInOrderInvokable implements StoppableTask {
+
+               @Override
+               public boolean triggerCheckpoint(CheckpointMetaData 
checkpointMetaData, CheckpointOptions checkpointOptions) {
+                       
classLoaders.add(Thread.currentThread().getContextClassLoader());
+
+                       return super.triggerCheckpoint(checkpointMetaData, 
checkpointOptions);
+               }
+
+               @Override
+               public void notifyCheckpointComplete(long checkpointId) {
+                       
classLoaders.add(Thread.currentThread().getContextClassLoader());
+
+                       super.notifyCheckpointComplete(checkpointId);
+               }
+
+               @Override
+               public void stop() {
+                       
classLoaders.add(Thread.currentThread().getContextClassLoader());
+                       stopLatch.trigger();
+               }
+
+       }
+
+       /**
+        * A {@link ClassLoader} that delegates everything to {@link 
ClassLoader#getSystemClassLoader()}.
+        *
+        * @see #testSetsUserCodeClassLoader()
+        */
+       private static class TestUserCodeClassLoader extends ClassLoader {
+               public TestUserCodeClassLoader() {
+                       super(ClassLoader.getSystemClassLoader());
+               }
+       }
+
+       private static class TaskCleaner implements AutoCloseable {
+
+               private final Task task;
+
+               private TaskCleaner(Task task) {
+                       this.task = task;
+               }
+
+               @Override
+               public void close() throws Exception {
+                       task.cancelExecution();
+                       task.getExecutingThread().join(5000);
+               }
+       }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/2117eb77/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java
deleted file mode 100644
index d062def..0000000
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java
+++ /dev/null
@@ -1,157 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.runtime.taskmanager;
-
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.TaskInfo;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.blob.BlobCacheService;
-import org.apache.flink.runtime.blob.PermanentBlobCache;
-import org.apache.flink.runtime.blob.TransientBlobCache;
-import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
-import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
-import org.apache.flink.runtime.clusterframework.types.AllocationID;
-import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
-import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
-import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
-import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
-import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.executiongraph.JobInformation;
-import org.apache.flink.runtime.executiongraph.TaskInformation;
-import org.apache.flink.runtime.filecache.FileCache;
-import org.apache.flink.runtime.io.disk.iomanager.IOManager;
-import org.apache.flink.runtime.io.network.NetworkEnvironment;
-import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
-import 
org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
-import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
-import org.apache.flink.runtime.jobgraph.tasks.StoppableTask;
-import org.apache.flink.runtime.memory.MemoryManager;
-import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
-import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
-
-import java.lang.reflect.Field;
-import java.util.Collections;
-import java.util.concurrent.Executor;
-
-import scala.concurrent.duration.FiniteDuration;
-
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({ TaskDeploymentDescriptor.class, JobID.class, 
FiniteDuration.class })
-public class TaskStopTest {
-       private Task task;
-
-       public void doMocking(AbstractInvokable taskMock) throws Exception {
-
-               TaskInfo taskInfoMock = mock(TaskInfo.class);
-               
when(taskInfoMock.getTaskNameWithSubtasks()).thenReturn("dummyName");
-
-               TaskManagerRuntimeInfo tmRuntimeInfo = 
mock(TaskManagerRuntimeInfo.class);
-               when(tmRuntimeInfo.getConfiguration()).thenReturn(new 
Configuration());
-
-               TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class);
-               
when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class));
-
-               BlobCacheService blobService =
-                       new BlobCacheService(mock(PermanentBlobCache.class), 
mock(TransientBlobCache.class));
-
-               task = new Task(
-                       mock(JobInformation.class),
-                       new TaskInformation(
-                               new JobVertexID(),
-                               "test task name",
-                               1,
-                               1,
-                               "foobar",
-                               new Configuration()),
-                       mock(ExecutionAttemptID.class),
-                       mock(AllocationID.class),
-                       0,
-                       0,
-                       
Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
-                       Collections.<InputGateDeploymentDescriptor>emptyList(),
-                       0,
-                       mock(TaskStateSnapshot.class),
-                       mock(MemoryManager.class),
-                       mock(IOManager.class),
-                       mock(NetworkEnvironment.class),
-                       mock(BroadcastVariableManager.class),
-                       mock(TaskManagerActions.class),
-                       mock(InputSplitProvider.class),
-                       mock(CheckpointResponder.class),
-                       blobService,
-                       mock(LibraryCacheManager.class),
-                       mock(FileCache.class),
-                       tmRuntimeInfo,
-                       taskMetricGroup,
-                       mock(ResultPartitionConsumableNotifier.class),
-                       mock(PartitionProducerStateChecker.class),
-                       mock(Executor.class));
-               Field f = task.getClass().getDeclaredField("invokable");
-               f.setAccessible(true);
-               f.set(task, taskMock);
-
-               Field f2 = task.getClass().getDeclaredField("executionState");
-               f2.setAccessible(true);
-               f2.set(task, ExecutionState.RUNNING);
-       }
-
-       @Test(timeout = 20000)
-       public void testStopExecution() throws Exception {
-               StoppableTestTask taskMock = new StoppableTestTask();
-               doMocking(taskMock);
-
-               task.stopExecution();
-
-               while (!taskMock.stopCalled) {
-                       Thread.sleep(100);
-               }
-       }
-
-       @Test(expected = RuntimeException.class)
-       public void testStopExecutionFail() throws Exception {
-               AbstractInvokable taskMock = mock(AbstractInvokable.class);
-               doMocking(taskMock);
-
-               task.stopExecution();
-       }
-
-       private final static class StoppableTestTask extends AbstractInvokable 
implements StoppableTask {
-               public volatile boolean stopCalled = false;
-
-               @Override
-               public void invoke() throws Exception {
-               }
-
-               @Override
-               public void stop() {
-                       this.stopCalled = true;
-               }
-       }
-
-}

Reply via email to