Repository: hive
Updated Branches:
  refs/heads/llap 53b0cb750 -> af7bf5754


HIVE-11273. LLAP: Register for finishable state change notifications when 
adding a task instead of when scheduling it. (Siddharth Seth)


Project: http://git-wip-us.apache.org/repos/asf/hive/repo
Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/f47810cb
Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/f47810cb
Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/f47810cb

Branch: refs/heads/llap
Commit: f47810cbbe87cbf1888ce2e096a167c27613a905
Parents: 53b0cb7
Author: Siddharth Seth <ss...@apache.org>
Authored: Fri Jul 17 08:46:15 2015 -0700
Committer: Siddharth Seth <ss...@apache.org>
Committed: Fri Jul 17 08:46:15 2015 -0700

----------------------------------------------------------------------
 .../llap/daemon/impl/TaskExecutorService.java   |  52 ++-
 .../daemon/impl/TestTaskExecutorService.java    | 419 ++++++++++++++-----
 2 files changed, 358 insertions(+), 113 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hive/blob/f47810cb/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorService.java
----------------------------------------------------------------------
diff --git a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorService.java b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorService.java
index f083a48..e6cf151 100644
--- a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorService.java
+++ b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorService.java
@@ -95,8 +95,10 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
   private final ThreadPoolExecutor threadPoolExecutor;
   private final AtomicInteger numSlotsAvailable;
 
+
+  @VisibleForTesting
   // Tracks known tasks.
-  private final ConcurrentMap<String, TaskWrapper> knownTasks = new ConcurrentHashMap<>();
+  final ConcurrentMap<String, TaskWrapper> knownTasks = new ConcurrentHashMap<>();
 
   private final Object lock = new Object();
 
@@ -219,9 +221,9 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
             if (task.getTaskRunnerCallable().canFinish()) {
               if (isDebugEnabled) {
                 LOG.debug(
-                    "Attempting to schedule task {}, canFinish={}. Current state: preemptionQueueSize={}, numSlotsAvailable={}",
+                    "Attempting to schedule task {}, canFinish={}. Current state: preemptionQueueSize={}, numSlotsAvailable={}, waitQueueSize={}",
                     task.getRequestId(), task.getTaskRunnerCallable().canFinish(),
-                    preemptionQueue.size(), numSlotsAvailable.get());
+                    preemptionQueue.size(), numSlotsAvailable.get(), waitQueue.size());
               }
               if (numSlotsAvailable.get() == 0 && preemptionQueue.isEmpty()) {
                 shouldWait = true;
@@ -294,12 +296,23 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
   public void schedule(TaskRunnerCallable task) throws RejectedExecutionException {
     TaskWrapper taskWrapper = new TaskWrapper(task, this);
     knownTasks.put(taskWrapper.getRequestId(), taskWrapper);
+
+    // Register for state change notifications so that the waitQueue can be re-ordered correctly
+    // if the fragment moves in or out of the finishable state.
+    boolean canFinish = taskWrapper.getTaskRunnerCallable().canFinish();
+    // It's safe to register outside of the lock since the stateChangeTracker ensures that updates
+    // and registrations are mutually exclusive.
+    taskWrapper.maybeRegisterForFinishedStateNotifications(canFinish);
+
     TaskWrapper evictedTask;
     try {
-      // Don't need a lock. Not subscribed for notifications yet, and marked as inWaitQueue
-      evictedTask = waitQueue.offer(taskWrapper);
+      synchronized (lock) {
+        evictedTask = waitQueue.offer(taskWrapper);
+        taskWrapper.setIsInWaitQueue(true);
+      }
     } catch (RejectedExecutionException e) {
       knownTasks.remove(taskWrapper.getRequestId());
+      taskWrapper.maybeUnregisterForFinishedStateNotifications();
       throw e;
     }
     if (isInfoEnabled) {
@@ -310,6 +323,7 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
     }
     if (evictedTask != null) {
       evictedTask.maybeUnregisterForFinishedStateNotifications();
+      evictedTask.setIsInWaitQueue(false);
       evictedTask.getTaskRunnerCallable().killTask();
       if (isInfoEnabled) {
         LOG.info("{} evicted from wait queue in favor of {} because of lower priority",
@@ -353,15 +367,11 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
 
     boolean scheduled = false;
     try {
-
-      boolean canFinish = taskWrapper.getTaskRunnerCallable().canFinish();
-      // It's safe to register outside of the lock since the stateChangeTracker ensures that updates
-      // and registrations are mutually exclusive.
-      boolean stateChanged = !taskWrapper.maybeRegisterForFinishedStateNotifications(canFinish);
       synchronized (lock) {
+        boolean canFinish = taskWrapper.getTaskRunnerCallable().canFinish();
         ListenableFuture<TaskRunner2Result> future = executorService.submit(taskWrapper.getTaskRunnerCallable());
         taskWrapper.setIsInWaitQueue(false);
-        FutureCallback<TaskRunner2Result> wrappedCallback = new InternalCompletionListener(taskWrapper);
+        FutureCallback<TaskRunner2Result> wrappedCallback = createInternalCompletionListener(taskWrapper);
         // Callback on a separate thread so that when a task completes, the thread in the main queue
         // is actually available for execution and will not potentially result in a RejectedExecution
         Futures.addCallback(future, wrappedCallback, executionCompletionExecutorService);
@@ -373,9 +383,10 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
         // only tasks that cannot finish immediately are pre-emptable. In other words, if all inputs
         // to the tasks are not ready yet, the task is eligible for pre-emptable.
         if (enablePreemption) {
-          if ((!canFinish && !stateChanged) || (canFinish && stateChanged)) {
+          if (!canFinish) {
             if (isInfoEnabled) {
-              LOG.info("{} is not finishable. Adding it to pre-emption queue", taskWrapper.getRequestId());
+              LOG.info("{} is not finishable. Adding it to pre-emption queue",
+                  taskWrapper.getRequestId());
             }
             addToPreemptionQueue(taskWrapper);
           }
@@ -422,7 +433,7 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
         LOG.info("DEBUG: Re-ordering the wait queue since {} finishable state moved to {}",
             taskWrapper.getRequestId(), newFinishableState);
         if (waitQueue.remove(taskWrapper)) {
-          // Put element back onlt if it existed.
+          // Put element back only if it existed.
           waitQueue.offer(taskWrapper);
         } else {
           LOG.warn("Failed to remove {} from waitQueue",
@@ -462,7 +473,13 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
     return taskWrapper;
   }
 
-  private final class InternalCompletionListener implements
+  @VisibleForTesting
+  InternalCompletionListener createInternalCompletionListener(TaskWrapper taskWrapper) {
+    return new InternalCompletionListener(taskWrapper);
+  }
+
+  @VisibleForTesting
+  class InternalCompletionListener implements
       FutureCallback<TaskRunner2Result> {
     private final TaskWrapper taskWrapper;
 
@@ -640,7 +657,7 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
 
   public static class TaskWrapper implements FinishableStateUpdateHandler {
     private final TaskRunnerCallable taskRunnerCallable;
-    private boolean inWaitQueue = true;
+    private boolean inWaitQueue = false;
     private boolean inPreemptionQueue = false;
     private boolean registeredForNotifications = false;
     private final TaskExecutorService taskExecutorService;
@@ -709,6 +726,9 @@ public class TaskExecutorService extends AbstractService implements Scheduler<Ta
           ", inWaitQueue=" + inWaitQueue +
           ", inPreemptionQueue=" + inPreemptionQueue +
           ", registeredForNotifications=" + registeredForNotifications +
+          ", canFinish=" + taskRunnerCallable.canFinish() +
+          ", firstAttemptStartTime=" + taskRunnerCallable.getFirstAttemptStartTime() +
+          ", vertexParallelism=" + taskRunnerCallable.getVertexParallelism() +
           '}';
     }
 

http://git-wip-us.apache.org/repos/asf/hive/blob/f47810cb/llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TestTaskExecutorService.java
----------------------------------------------------------------------
diff --git a/llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TestTaskExecutorService.java b/llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TestTaskExecutorService.java
index dd5b457..eff4abe 100644
--- a/llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TestTaskExecutorService.java
+++ b/llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TestTaskExecutorService.java
@@ -19,10 +19,18 @@ package org.apache.hadoop.hive.llap.daemon.impl;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
 
 import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.PriorityBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.llap.daemon.FragmentCompletionHandler;
@@ -32,7 +40,9 @@ import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos;
 import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.EntityDescriptorProto;
 import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.FragmentSpecProto;
 import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.SubmitWorkRequestProto;
+import org.apache.hadoop.hive.llap.metrics.LlapDaemonExecutorMetrics;
 import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
@@ -43,84 +53,26 @@ import org.apache.tez.runtime.task.EndReason;
 import org.apache.tez.runtime.task.TaskRunner2Result;
 import org.junit.Before;
 import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class TestTaskExecutorService {
   private static Configuration conf;
   private static Credentials cred = new Credentials();
-
-  private static class MockRequest extends TaskRunnerCallable {
-    private int workTime;
-    private boolean canFinish;
-
-    public MockRequest(LlapDaemonProtocolProtos.SubmitWorkRequestProto requestProto,
-        boolean canFinish, int workTime) {
-      super(requestProto, mock(QueryFragmentInfo.class), conf,
-          new ExecutionContextImpl("localhost"), null, cred, 0, null, null, null,
-          mock(KilledTaskHandler.class), mock(
-          FragmentCompletionHandler.class));
-      this.workTime = workTime;
-      this.canFinish = canFinish;
-    }
-
-    @Override
-    protected TaskRunner2Result callInternal() {
-      System.out.println(super.getRequestId() + " is executing..");
-      try {
-        Thread.sleep(workTime);
-      } catch (InterruptedException e) {
-        return new TaskRunner2Result(EndReason.KILL_REQUESTED, null, false);
-      }
-      return new TaskRunner2Result(EndReason.SUCCESS, null, false);
-    }
-
-    @Override
-    public boolean canFinish() {
-      return canFinish;
-    }
-  }
+  private static final Logger LOG = LoggerFactory.getLogger(TestTaskExecutorService.class);
 
   @Before
   public void setup() {
     conf = new Configuration();
   }
 
-  private SubmitWorkRequestProto createRequest(int fragmentNumber, int parallelism, int attemptStartTime) {
-    ApplicationId appId = ApplicationId.newInstance(9999, 72);
-    TezDAGID dagId = TezDAGID.getInstance(appId, 1);
-    TezVertexID vId = TezVertexID.getInstance(dagId, 35);
-    TezTaskID tId = TezTaskID.getInstance(vId, 389);
-    TezTaskAttemptID taId = TezTaskAttemptID.getInstance(tId, fragmentNumber);
-    return SubmitWorkRequestProto
-        .newBuilder()
-        .setFragmentSpec(
-            FragmentSpecProto
-                .newBuilder()
-                .setAttemptNumber(0)
-                .setDagName("MockDag")
-                .setFragmentNumber(fragmentNumber)
-                .setVertexName("MockVertex")
-                .setVertexParallelism(parallelism)
-                .setProcessorDescriptor(
-                    EntityDescriptorProto.newBuilder().setClassName("MockProcessor").build())
-                .setFragmentIdentifierString(taId.toString()).build()).setAmHost("localhost")
-        .setAmPort(12345).setAppAttemptNumber(0).setApplicationIdString("MockApp_1")
-        .setContainerIdString("MockContainer_1").setUser("MockUser")
-        .setTokenIdentifier("MockToken_1")
-        .setFragmentRuntimeInfo(LlapDaemonProtocolProtos
-            .FragmentRuntimeInfo
-            .newBuilder()
-            .setFirstAttemptStartTime(attemptStartTime)
-            .build())
-        .build();
-  }
-
-  @Test
+  @Test(timeout = 5000)
   public void testWaitQueueComparator() throws InterruptedException {
-    TaskWrapper r1 = createTaskWrapper(createRequest(1, 2, 100), false, 100000);
-    TaskWrapper r2 = createTaskWrapper(createRequest(2, 4, 200), false, 100000);
-    TaskWrapper r3 = createTaskWrapper(createRequest(3, 6, 300), false, 1000000);
-    TaskWrapper r4 = createTaskWrapper(createRequest(4, 8, 400), false, 1000000);
-    TaskWrapper r5 = createTaskWrapper(createRequest(5, 10, 500), false, 1000000);
+    TaskWrapper r1 = createTaskWrapper(createSubmitWorkRequestProto(1, 2, 100), false, 100000);
+    TaskWrapper r2 = createTaskWrapper(createSubmitWorkRequestProto(2, 4, 200), false, 100000);
+    TaskWrapper r3 = createTaskWrapper(createSubmitWorkRequestProto(3, 6, 300), false, 1000000);
+    TaskWrapper r4 = createTaskWrapper(createSubmitWorkRequestProto(4, 8, 400), false, 1000000);
+    TaskWrapper r5 = createTaskWrapper(createSubmitWorkRequestProto(5, 10, 500), false, 1000000);
     EvictingPriorityBlockingQueue<TaskWrapper> queue = new EvictingPriorityBlockingQueue<>(
         new TaskExecutorService.ShortestJobFirstComparator(), 4);
     assertNull(queue.offer(r1));
@@ -138,11 +90,11 @@ public class TestTaskExecutorService {
     assertEquals(r3, queue.take());
     assertEquals(r4, queue.take());
 
-    r1 = createTaskWrapper(createRequest(1, 2, 100), true, 100000);
-    r2 = createTaskWrapper(createRequest(2, 4, 200), true, 100000);
-    r3 = createTaskWrapper(createRequest(3, 6, 300), true, 1000000);
-    r4 = createTaskWrapper(createRequest(4, 8, 400), true, 1000000);
-    r5 = createTaskWrapper(createRequest(5, 10, 500), true, 1000000);
+    r1 = createTaskWrapper(createSubmitWorkRequestProto(1, 2, 100), true, 100000);
+    r2 = createTaskWrapper(createSubmitWorkRequestProto(2, 4, 200), true, 100000);
+    r3 = createTaskWrapper(createSubmitWorkRequestProto(3, 6, 300), true, 1000000);
+    r4 = createTaskWrapper(createSubmitWorkRequestProto(4, 8, 400), true, 1000000);
+    r5 = createTaskWrapper(createSubmitWorkRequestProto(5, 10, 500), true, 1000000);
     queue = new EvictingPriorityBlockingQueue(
         new TaskExecutorService.ShortestJobFirstComparator(), 4);
     assertNull(queue.offer(r1));
@@ -160,11 +112,11 @@ public class TestTaskExecutorService {
     assertEquals(r3, queue.take());
     assertEquals(r4, queue.take());
 
-    r1 = createTaskWrapper(createRequest(1, 1, 100), true, 100000);
-    r2 = createTaskWrapper(createRequest(2, 1, 200), false, 100000);
-    r3 = createTaskWrapper(createRequest(3, 1, 300), true, 1000000);
-    r4 = createTaskWrapper(createRequest(4, 1, 400), false, 1000000);
-    r5 = createTaskWrapper(createRequest(5, 10, 500), true, 1000000);
+    r1 = createTaskWrapper(createSubmitWorkRequestProto(1, 1, 100), true, 100000);
+    r2 = createTaskWrapper(createSubmitWorkRequestProto(2, 1, 200), false, 100000);
+    r3 = createTaskWrapper(createSubmitWorkRequestProto(3, 1, 300), true, 1000000);
+    r4 = createTaskWrapper(createSubmitWorkRequestProto(4, 1, 400), false, 1000000);
+    r5 = createTaskWrapper(createSubmitWorkRequestProto(5, 10, 500), true, 1000000);
     queue = new EvictingPriorityBlockingQueue(
         new TaskExecutorService.ShortestJobFirstComparator(), 4);
     assertNull(queue.offer(r1));
@@ -182,11 +134,11 @@ public class TestTaskExecutorService {
     assertEquals(r5, queue.take());
     assertEquals(r2, queue.take());
 
-    r1 = createTaskWrapper(createRequest(1, 2, 100), true, 100000);
-    r2 = createTaskWrapper(createRequest(2, 4, 200), false, 100000);
-    r3 = createTaskWrapper(createRequest(3, 6, 300), true, 1000000);
-    r4 = createTaskWrapper(createRequest(4, 8, 400), false, 1000000);
-    r5 = createTaskWrapper(createRequest(5, 10, 500), true, 1000000);
+    r1 = createTaskWrapper(createSubmitWorkRequestProto(1, 2, 100), true, 100000);
+    r2 = createTaskWrapper(createSubmitWorkRequestProto(2, 4, 200), false, 100000);
+    r3 = createTaskWrapper(createSubmitWorkRequestProto(3, 6, 300), true, 1000000);
+    r4 = createTaskWrapper(createSubmitWorkRequestProto(4, 8, 400), false, 1000000);
+    r5 = createTaskWrapper(createSubmitWorkRequestProto(5, 10, 500), true, 1000000);
     queue = new EvictingPriorityBlockingQueue(
         new TaskExecutorService.ShortestJobFirstComparator(), 4);
     assertNull(queue.offer(r1));
@@ -204,11 +156,11 @@ public class TestTaskExecutorService {
     assertEquals(r5, queue.take());
     assertEquals(r2, queue.take());
 
-    r1 = createTaskWrapper(createRequest(1, 2, 100), true, 100000);
-    r2 = createTaskWrapper(createRequest(2, 4, 200), false, 100000);
-    r3 = createTaskWrapper(createRequest(3, 6, 300), false, 1000000);
-    r4 = createTaskWrapper(createRequest(4, 8, 400), false, 1000000);
-    r5 = createTaskWrapper(createRequest(5, 10, 500), true, 1000000);
+    r1 = createTaskWrapper(createSubmitWorkRequestProto(1, 2, 100), true, 100000);
+    r2 = createTaskWrapper(createSubmitWorkRequestProto(2, 4, 200), false, 100000);
+    r3 = createTaskWrapper(createSubmitWorkRequestProto(3, 6, 300), false, 1000000);
+    r4 = createTaskWrapper(createSubmitWorkRequestProto(4, 8, 400), false, 1000000);
+    r5 = createTaskWrapper(createSubmitWorkRequestProto(5, 10, 500), true, 1000000);
     queue = new EvictingPriorityBlockingQueue(
         new TaskExecutorService.ShortestJobFirstComparator(), 4);
     assertNull(queue.offer(r1));
@@ -226,11 +178,11 @@ public class TestTaskExecutorService {
     assertEquals(r2, queue.take());
     assertEquals(r3, queue.take());
 
-    r1 = createTaskWrapper(createRequest(1, 2, 100), false, 100000);
-    r2 = createTaskWrapper(createRequest(2, 4, 200), true, 100000);
-    r3 = createTaskWrapper(createRequest(3, 6, 300), true, 1000000);
-    r4 = createTaskWrapper(createRequest(4, 8, 400), true, 1000000);
-    r5 = createTaskWrapper(createRequest(5, 10, 500), true, 1000000);
+    r1 = createTaskWrapper(createSubmitWorkRequestProto(1, 2, 100), false, 100000);
+    r2 = createTaskWrapper(createSubmitWorkRequestProto(2, 4, 200), true, 100000);
+    r3 = createTaskWrapper(createSubmitWorkRequestProto(3, 6, 300), true, 1000000);
+    r4 = createTaskWrapper(createSubmitWorkRequestProto(4, 8, 400), true, 1000000);
+    r5 = createTaskWrapper(createSubmitWorkRequestProto(5, 10, 500), true, 1000000);
     queue = new EvictingPriorityBlockingQueue(
         new TaskExecutorService.ShortestJobFirstComparator(), 4);
     assertNull(queue.offer(r1));
@@ -249,12 +201,12 @@ public class TestTaskExecutorService {
     assertEquals(r5, queue.take());
   }
 
-  @Test
+  @Test(timeout = 5000)
   public void testPreemptionQueueComparator() throws InterruptedException {
-    TaskWrapper r1 = createTaskWrapper(createRequest(1, 2, 100), false, 100000);
-    TaskWrapper r2 = createTaskWrapper(createRequest(2, 4, 200), false, 100000);
-    TaskWrapper r3 = createTaskWrapper(createRequest(3, 6, 300), false, 1000000);
-    TaskWrapper r4 = createTaskWrapper(createRequest(4, 8, 400), false, 1000000);
+    TaskWrapper r1 = createTaskWrapper(createSubmitWorkRequestProto(1, 2, 100), false, 100000);
+    TaskWrapper r2 = createTaskWrapper(createSubmitWorkRequestProto(2, 4, 200), false, 100000);
+    TaskWrapper r3 = createTaskWrapper(createSubmitWorkRequestProto(3, 6, 300), false, 1000000);
+    TaskWrapper r4 = createTaskWrapper(createSubmitWorkRequestProto(4, 8, 400), false, 1000000);
     BlockingQueue<TaskWrapper> queue = new PriorityBlockingQueue<>(4,
         new TaskExecutorService.PreemptionQueueComparator());
 
@@ -271,9 +223,282 @@ public class TestTaskExecutorService {
     assertEquals(r4, queue.take());
   }
 
+  @Test(timeout = 10000)
+  public void testFinishablePreeptsNonFinishable() throws InterruptedException {
+    MockRequest r1 = createMockRequest(1, 1, 100, false, 5000l);
+    MockRequest r2 = createMockRequest(2, 1, 100, true, 1000l);
+    TaskExecutorServiceForTest taskExecutorService = new TaskExecutorServiceForTest(1, 2, false, true);
+    taskExecutorService.init(conf);
+    taskExecutorService.start();
+
+    try {
+      taskExecutorService.schedule(r1);
+      r1.awaitStart();
+      taskExecutorService.schedule(r2);
+      r2.awaitStart();
+      // Verify r1 was preempted. Also verify that it finished (single executor), otherwise
+      // r2 could have run anyway.
+      assertTrue(r1.wasPreempted());
+      assertTrue(r1.hasFinished());
+
+      r2.complete();
+      r2.awaitEnd();
+
+      TaskExecutorServiceForTest.InternalCompletionListenerForTest icl1 =
+          taskExecutorService.getInternalCompletionListenerForTest(r1.getRequestId());
+      TaskExecutorServiceForTest.InternalCompletionListenerForTest icl2 =
+          taskExecutorService.getInternalCompletionListenerForTest(r2.getRequestId());
+
+      // Ensure Data structures are updated in the main TaskScheduler
+      icl1.awaitCompletion();
+      icl2.awaitCompletion();
+
+      assertEquals(0, taskExecutorService.knownTasks.size());
+    } finally {
+      taskExecutorService.shutDown(false);
+    }
+  }
+
+
+  // ----------- Helper classes and methods go after this point. Tests above this -----------
+
+  private SubmitWorkRequestProto createSubmitWorkRequestProto(int fragmentNumber, int parallelism,
+                                                              long attemptStartTime) {
+    ApplicationId appId = ApplicationId.newInstance(9999, 72);
+    TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+    TezVertexID vId = TezVertexID.getInstance(dagId, 35);
+    TezTaskID tId = TezTaskID.getInstance(vId, 389);
+    TezTaskAttemptID taId = TezTaskAttemptID.getInstance(tId, fragmentNumber);
+    return SubmitWorkRequestProto
+        .newBuilder()
+        .setFragmentSpec(
+            FragmentSpecProto
+                .newBuilder()
+                .setAttemptNumber(0)
+                .setDagName("MockDag")
+                .setFragmentNumber(fragmentNumber)
+                .setVertexName("MockVertex")
+                .setVertexParallelism(parallelism)
+                .setProcessorDescriptor(
+                    EntityDescriptorProto.newBuilder().setClassName("MockProcessor").build())
+                .setFragmentIdentifierString(taId.toString()).build()).setAmHost("localhost")
+        .setAmPort(12345).setAppAttemptNumber(0).setApplicationIdString("MockApp_1")
+        .setContainerIdString("MockContainer_1").setUser("MockUser")
+        .setTokenIdentifier("MockToken_1")
+        .setFragmentRuntimeInfo(LlapDaemonProtocolProtos
+            .FragmentRuntimeInfo
+            .newBuilder()
+            .setFirstAttemptStartTime(attemptStartTime)
+            .build())
+        .build();
+  }
+
+  private MockRequest createMockRequest(int fragmentNum, int parallelism, long startTime,
+                                        boolean canFinish, long workTime) {
+    SubmitWorkRequestProto requestProto = createSubmitWorkRequestProto(fragmentNum, parallelism,
+        startTime);
+    MockRequest mockRequest = new MockRequest(requestProto, canFinish, workTime);
+    return mockRequest;
+  }
+
   private TaskWrapper createTaskWrapper(SubmitWorkRequestProto request, boolean canFinish, int workTime) {
     MockRequest mockRequest = new MockRequest(request, canFinish, workTime);
     TaskWrapper taskWrapper = new TaskWrapper(mockRequest, null);
     return taskWrapper;
   }
+
+  private static void logInfo(String message, Throwable t) {
+    LOG.info(message, t);
+  }
+
+  private static void logInfo(String message) {
+    logInfo(message, null);
+  }
+
+  private static class MockRequest extends TaskRunnerCallable {
+    private final long workTime;
+    private final boolean canFinish;
+
+    private final AtomicBoolean isStarted = new AtomicBoolean(false);
+    private final AtomicBoolean isFinished = new AtomicBoolean(false);
+    private final AtomicBoolean wasKilled = new AtomicBoolean(false);
+    private final AtomicBoolean wasInterrupted = new AtomicBoolean(false);
+
+    private final ReentrantLock lock = new ReentrantLock();
+    private final Condition startedCondition = lock.newCondition();
+    private final Condition sleepCondition = lock.newCondition();
+    private final Condition finishedCondition = lock.newCondition();
+
+    public MockRequest(LlapDaemonProtocolProtos.SubmitWorkRequestProto requestProto,
+                       boolean canFinish, long workTime) {
+      super(requestProto, mock(QueryFragmentInfo.class), conf,
+          new ExecutionContextImpl("localhost"), null, cred, 0, null, null, mock(
+              LlapDaemonExecutorMetrics.class),
+          mock(KilledTaskHandler.class), mock(
+              FragmentCompletionHandler.class));
+      this.workTime = workTime;
+      this.canFinish = canFinish;
+    }
+
+    @Override
+    protected TaskRunner2Result callInternal() {
+      try {
+        logInfo(super.getRequestId() + " is executing..", null);
+        lock.lock();
+        try {
+          isStarted.set(true);
+          startedCondition.signal();
+        } finally {
+          lock.unlock();
+        }
+
+        lock.lock();
+        try {
+          sleepCondition.await(workTime, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+          wasInterrupted.set(true);
+          return new TaskRunner2Result(EndReason.KILL_REQUESTED, null, false);
+        } finally {
+          lock.unlock();
+        }
+        if (wasKilled.get()) {
+          return new TaskRunner2Result(EndReason.KILL_REQUESTED, null, false);
+        } else {
+          return new TaskRunner2Result(EndReason.SUCCESS, null, false);
+        }
+      } finally {
+        lock.lock();
+        try {
+          isFinished.set(true);
+          finishedCondition.signal();
+        } finally {
+          lock.unlock();
+        }
+      }
+    }
+
+    @Override
+    public void killTask() {
+      lock.lock();
+      try {
+        wasKilled.set(true);
+        sleepCondition.signal();
+      } finally {
+        lock.unlock();
+      }
+    }
+
+    boolean hasStarted() {
+      return isStarted.get();
+    }
+
+    boolean hasFinished() {
+      return isFinished.get();
+    }
+
+    boolean wasPreempted() {
+      return wasKilled.get();
+    }
+
+    void complete() {
+      lock.lock();
+      try {
+        sleepCondition.signal();
+      } finally {
+        lock.unlock();
+      }
+    }
+
+    void awaitStart() throws InterruptedException {
+      lock.lock();
+      try {
+        while (!isStarted.get()) {
+          startedCondition.await();
+        }
+      } finally {
+        lock.unlock();
+      }
+    }
+
+    void awaitEnd() throws InterruptedException {
+      lock.lock();
+      try {
+        while (!isFinished.get()) {
+          finishedCondition.await();
+        }
+      } finally {
+        lock.unlock();
+      }
+    }
+
+
+    @Override
+    public boolean canFinish() {
+      return canFinish;
+    }
+  }
+
+  private static class TaskExecutorServiceForTest extends TaskExecutorService {
+    public TaskExecutorServiceForTest(int numExecutors, int waitQueueSize, boolean useFairOrdering,
+                                      boolean enablePreemption) {
+      super(numExecutors, waitQueueSize, useFairOrdering, enablePreemption);
+    }
+
+    private ConcurrentMap<String, InternalCompletionListenerForTest> completionListeners = new ConcurrentHashMap<>();
+
+    InternalCompletionListener createInternalCompletionListener(TaskWrapper taskWrapper) {
+      InternalCompletionListenerForTest icl = new InternalCompletionListenerForTest(taskWrapper);
+      completionListeners.put(taskWrapper.getRequestId(), icl);
+      return icl;
+    }
+
+    InternalCompletionListenerForTest getInternalCompletionListenerForTest(String requestId) {
+      return completionListeners.get(requestId);
+    }
+
+
+    private class InternalCompletionListenerForTest extends TaskExecutorService.InternalCompletionListener {
+
+      private final Lock lock = new ReentrantLock();
+      private final Condition completionCondition = lock.newCondition();
+      private final AtomicBoolean isComplete = new AtomicBoolean(false);
+
+      public InternalCompletionListenerForTest(TaskWrapper taskWrapper) {
+        super(taskWrapper);
+      }
+
+      @Override
+      public void onSuccess(TaskRunner2Result result) {
+        super.onSuccess(result);
+        markComplete();
+      }
+
+      @Override
+      public void onFailure(Throwable t) {
+        super.onFailure(t);
+        markComplete();
+      }
+
+      private void markComplete() {
+        lock.lock();
+        try {
+          isComplete.set(true);
+          completionCondition.signal();
+        } finally {
+          lock.unlock();
+        }
+      }
+
+      private void awaitCompletion() throws InterruptedException {
+        lock.lock();
+        try {
+          while (!isComplete.get()) {
+            completionCondition.await();
+          }
+        } finally {
+          lock.unlock();
+        }
+      }
+    }
+  }
 }

Reply via email to