Repository: tez
Updated Branches:
  refs/heads/master bb40cf5b8 -> c34e46c73


TEZ-3897. Tez Local Mode hang for vertices with broadcast input. (Jonathan Eagles via jlowe)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/c34e46c7
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/c34e46c7
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/c34e46c7

Branch: refs/heads/master
Commit: c34e46c73218bf21a0219f3004e20cbedaad92f4
Parents: bb40cf5
Author: Jason Lowe <jl...@apache.org>
Authored: Mon Mar 5 09:53:11 2018 -0600
Committer: Jason Lowe <jl...@apache.org>
Committed: Mon Mar 5 09:53:11 2018 -0600

----------------------------------------------------------------------
 .../app/launcher/LocalContainerLauncher.java    |  19 +-
 .../dag/app/rm/LocalTaskSchedulerService.java   | 185 ++++++++++++++-----
 .../tez/dag/app/rm/TestLocalTaskScheduler.java  |   8 +-
 .../app/rm/TestLocalTaskSchedulerService.java   |  94 +++++++++-
 4 files changed, 243 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/c34e46c7/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
index 9764daa..13e4115 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
@@ -94,9 +94,9 @@ public class LocalContainerLauncher extends DagContainerLauncher {
   int shufflePort = TezRuntimeUtils.INVALID_PORT;
   private DeletionTracker deletionTracker;
 
-  private final ConcurrentHashMap<ContainerId, RunningTaskCallback>
+  private final ConcurrentHashMap<ContainerId, ListenableFuture<?>>
       runningContainers =
-      new ConcurrentHashMap<ContainerId, RunningTaskCallback>();
+      new ConcurrentHashMap<>();
 
   private final ConcurrentHashMap<ContainerId, TezLocalCacheManager>
           cacheManagers = new ConcurrentHashMap<>();
@@ -281,7 +281,7 @@ public class LocalContainerLauncher extends DagContainerLauncher {
       ListenableFuture<TezChild.ContainerExecutionResult> runningTaskFuture =
           taskExecutorService.submit(createSubTask(tezChild, 
event.getContainerId()));
       RunningTaskCallback callback = new 
RunningTaskCallback(event.getContainerId());
-      runningContainers.put(event.getContainerId(), callback);
+      runningContainers.put(event.getContainerId(), runningTaskFuture);
       Futures.addCallback(runningTaskFuture, callback, callbackExecutor);
       if (deletionTracker != null) {
         deletionTracker.addNodeShufflePort(event.getNodeId(), shufflePort);
@@ -293,19 +293,16 @@ public class LocalContainerLauncher extends DagContainerLauncher {
 
   private void stop(ContainerStopRequest event) {
     // A stop_request will come in when a task completes and reports back or a 
preemption decision
-    // is made. Currently the LocalTaskScheduler does not support preemption. 
Also preemption
-    // will not work in local mode till Tez supports task preemption instead 
of container preemption.
-    RunningTaskCallback callback =
+    // is made.
+    ListenableFuture future =
         runningContainers.get(event.getContainerId());
-    if (callback == null) {
+    if (future == null) {
       LOG.info("Ignoring stop request for containerId: " + 
event.getContainerId());
     } else {
       LOG.info(
-          "Ignoring stop request for containerId {}. Relying on regular task 
shutdown for it to end",
+          "Stopping containerId: {}",
           event.getContainerId());
-      // Allow the tezChild thread to run it's course. It'll receive a 
shutdown request from the
-      // AM eventually since the task and container will be unregistered.
-      // This will need to be fixed once interrupting tasks is supported.
+      future.cancel(true);
     }
     // Send this event to maintain regular control flow. This isn't of much 
use though.
     getContext().containerStopRequested(event.getContainerId());

http://git-wip-us.apache.org/repos/asf/tez/blob/c34e46c7/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java
index 04e79a8..cc213cb 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java
@@ -19,6 +19,9 @@
 package org.apache.tez.dag.app.rm;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.Map;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.PriorityBlockingQueue;
@@ -29,6 +32,7 @@ import java.util.LinkedHashMap;
 import com.google.common.primitives.Ints;
 
 import org.apache.tez.common.TezUtils;
+import org.apache.tez.serviceplugins.api.DagInfo;
 import org.apache.tez.serviceplugins.api.TaskScheduler;
 import org.apache.tez.serviceplugins.api.TaskSchedulerContext;
 import org.slf4j.Logger;
@@ -51,19 +55,19 @@ public class LocalTaskSchedulerService extends TaskScheduler {
   private static final Logger LOG = 
LoggerFactory.getLogger(LocalTaskSchedulerService.class);
 
   final ContainerSignatureMatcher containerSignatureMatcher;
-  final LinkedBlockingQueue<TaskRequest> taskRequestQueue;
+  final LinkedBlockingQueue<SchedulerRequest> taskRequestQueue;
   final Configuration conf;
   AsyncDelegateRequestHandler taskRequestHandler;
   Thread asyncDelegateRequestThread;
 
-  final HashMap<Object, Container> taskAllocations;
+  final HashMap<Object, AllocatedTask> taskAllocations;
   final String appTrackingUrl;
   final long customContainerAppId;
 
   public LocalTaskSchedulerService(TaskSchedulerContext taskSchedulerContext) {
     super(taskSchedulerContext);
     taskRequestQueue = new LinkedBlockingQueue<>();
-    taskAllocations = new LinkedHashMap<Object, Container>();
+    taskAllocations = new LinkedHashMap<>();
     this.appTrackingUrl = taskSchedulerContext.getAppTrackingUrl();
     this.containerSignatureMatcher = 
taskSchedulerContext.getContainerSignatureMatcher();
     this.customContainerAppId = 
taskSchedulerContext.getCustomClusterIdentifier();
@@ -98,6 +102,7 @@ public class LocalTaskSchedulerService extends TaskScheduler {
 
   @Override
   public void dagComplete() {
+    taskRequestHandler.dagComplete();
   }
 
   @Override
@@ -129,7 +134,7 @@ public class LocalTaskSchedulerService extends TaskScheduler {
     // in local mode every task is already container level local
     taskRequestHandler.addAllocateTaskRequest(task, capability, priority, 
clientCookie);
   }
-  
+
   @Override
   public boolean deallocateTask(Object task, boolean taskSucceeded, 
TaskAttemptEndReason endReason, String diagnostics) {
     return taskRequestHandler.addDeallocateTaskRequest(task);
@@ -137,6 +142,7 @@ public class LocalTaskSchedulerService extends TaskScheduler {
 
   @Override
   public Object deallocateContainer(ContainerId containerId) {
+    taskRequestHandler.addDeallocateContainerRequest(containerId);
     return null;
   }
 
@@ -212,20 +218,14 @@ public class LocalTaskSchedulerService extends TaskScheduler {
     }
   }
 
-  static class TaskRequest implements Comparable<TaskRequest> {
-    // Higher prority than Priority.UNDEFINED
-    static final int HIGHEST_PRIORITY = -2;
-    Object task;
-    Priority priority;
+  static class SchedulerRequest {
+  }
 
-    public TaskRequest(Object task, Priority priority) {
-      this.task = task;
-      this.priority = priority;
-    }
+  static class TaskRequest extends SchedulerRequest {
+    final Object task;
 
-    @Override
-    public int compareTo(TaskRequest request) {
-      return request.priority.compareTo(this.priority);
+    public TaskRequest(Object task) {
+      this.task = task;
     }
 
     @Override
@@ -239,9 +239,6 @@ public class LocalTaskSchedulerService extends TaskScheduler {
 
       TaskRequest that = (TaskRequest) o;
 
-      if (priority != null ? !priority.equals(that.priority) : that.priority 
!= null) {
-        return false;
-      }
       if (task != null ? !task.equals(that.task) : that.task != null) {
         return false;
       }
@@ -251,23 +248,29 @@ public class LocalTaskSchedulerService extends TaskScheduler {
 
     @Override
     public int hashCode() {
-      int result = 1;
-      result = 7841 * result + (task != null ? task.hashCode() : 0);
-      result = 7841 * result + (priority != null ? priority.hashCode() : 0);
-      return result;
+      return 7841 + (task != null ? task.hashCode() : 0);
     }
 
   }
 
-  static class AllocateTaskRequest extends TaskRequest {
-    Resource capability;
-    Object clientCookie;
+  static class AllocateTaskRequest extends TaskRequest implements 
Comparable<AllocateTaskRequest> {
+    final Priority priority;
+    final Resource capability;
+    final Object clientCookie;
+    final int vertexIndex;
 
-    public AllocateTaskRequest(Object task, Resource capability, Priority 
priority,
-        Object clientCookie) {
-      super(task, priority);
+    public AllocateTaskRequest(Object task, int vertexIndex, Resource 
capability, Priority priority,
+                               Object clientCookie) {
+      super(task);
+      this.priority = priority;
       this.capability = capability;
       this.clientCookie = clientCookie;
+      this.vertexIndex = vertexIndex;
+    }
+
+    @Override
+    public int compareTo(AllocateTaskRequest request) {
+      return request.priority.compareTo(this.priority);
     }
 
     @Override
@@ -284,6 +287,10 @@ public class LocalTaskSchedulerService extends TaskScheduler {
 
       AllocateTaskRequest that = (AllocateTaskRequest) o;
 
+      if (priority != null ? !priority.equals(that.priority) : that.priority 
!= null) {
+        return false;
+      }
+
       if (capability != null ? !capability.equals(that.capability) : 
that.capability != null) {
         return false;
       }
@@ -298,6 +305,7 @@ public class LocalTaskSchedulerService extends TaskScheduler {
     @Override
     public int hashCode() {
       int result = super.hashCode();
+      result = 12329 * result + (priority != null ? priority.hashCode() : 0);
       result = 12329 * result + (capability != null ? capability.hashCode() : 
0);
       result = 12329 * result + (clientCookie != null ? 
clientCookie.hashCode() : 0);
       return result;
@@ -305,24 +313,43 @@ public class LocalTaskSchedulerService extends TaskScheduler {
   }
 
   static class DeallocateTaskRequest extends TaskRequest {
-    static final Priority DEALLOCATE_PRIORITY = 
Priority.newInstance(HIGHEST_PRIORITY);
 
     public DeallocateTaskRequest(Object task) {
-      super(task, DEALLOCATE_PRIORITY);
+      super(task);
+    }
+  }
+
+  static class DeallocateContainerRequest extends SchedulerRequest {
+    final ContainerId containerId;
+
+    public DeallocateContainerRequest(ContainerId containerId) {
+      this.containerId = containerId;
+    }
+  }
+
+  static class AllocatedTask {
+    final AllocateTaskRequest request;
+    final Container container;
+
+    AllocatedTask(AllocateTaskRequest request, Container container) {
+      this.request = request;
+      this.container = container;
     }
   }
 
   static class AsyncDelegateRequestHandler implements Runnable {
-    final LinkedBlockingQueue<TaskRequest> clientRequestQueue;
+    final LinkedBlockingQueue<SchedulerRequest> clientRequestQueue;
     final PriorityBlockingQueue<AllocateTaskRequest> taskRequestQueue;
     final LocalContainerFactory localContainerFactory;
-    final HashMap<Object, Container> taskAllocations;
+    final HashMap<Object, AllocatedTask> taskAllocations;
     final TaskSchedulerContext taskSchedulerContext;
+    private final Object descendantsLock = new Object();
+    private ArrayList<BitSet> vertexDescendants = null;
     final int MAX_TASKS;
 
-    AsyncDelegateRequestHandler(LinkedBlockingQueue<TaskRequest> 
clientRequestQueue,
+    AsyncDelegateRequestHandler(LinkedBlockingQueue<SchedulerRequest> 
clientRequestQueue,
         LocalContainerFactory localContainerFactory,
-        HashMap<Object, Container> taskAllocations,
+        HashMap<Object, AllocatedTask> taskAllocations,
         TaskSchedulerContext taskSchedulerContext,
         Configuration conf) {
       this.clientRequestQueue = clientRequestQueue;
@@ -334,10 +361,33 @@ public class LocalTaskSchedulerService extends TaskScheduler {
       this.taskRequestQueue = new PriorityBlockingQueue<>();
     }
 
+    void dagComplete() {
+      synchronized (descendantsLock) {
+        vertexDescendants = null;
+      }
+    }
+    private void ensureVertexDescendants() {
+      synchronized (descendantsLock) {
+        if (vertexDescendants == null) {
+          DagInfo info = taskSchedulerContext.getCurrentDagInfo();
+          if (info == null) {
+            throw new IllegalStateException("Scheduling tasks but no current 
DAG info?");
+          }
+          int numVertices = info.getTotalVertices();
+          ArrayList<BitSet> descendants = new ArrayList<>(numVertices);
+          for (int i = 0; i < numVertices; ++i) {
+            descendants.add(info.getVertexDescendants(i));
+          }
+          vertexDescendants = descendants;
+        }
+      }
+    }
+
     public void addAllocateTaskRequest(Object task, Resource capability, 
Priority priority,
         Object clientCookie) {
       try {
-        clientRequestQueue.put(new AllocateTaskRequest(task, capability, 
priority, clientCookie));
+        int vertexIndex = taskSchedulerContext.getVertexIndexForTask(task);
+        clientRequestQueue.put(new AllocateTaskRequest(task, vertexIndex, 
capability, priority, clientCookie));
       } catch (InterruptedException e) {
         Thread.currentThread().interrupt();
       }
@@ -352,10 +402,22 @@ public class LocalTaskSchedulerService extends TaskScheduler {
       return true;
     }
 
+    public void addDeallocateContainerRequest(ContainerId containerId) {
+      try {
+        clientRequestQueue.put(new DeallocateContainerRequest(containerId));
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+      }
+    }
+
     boolean shouldProcess() {
       return !taskRequestQueue.isEmpty() && taskAllocations.size() < MAX_TASKS;
     }
 
+    boolean shouldPreempt() {
+      return !taskRequestQueue.isEmpty() && taskAllocations.size() >= 
MAX_TASKS;
+    }
+
     @Override
     public void run() {
       while (!Thread.currentThread().isInterrupted()) {
@@ -368,13 +430,19 @@ public class LocalTaskSchedulerService extends TaskScheduler {
 
     void dispatchRequest() {
       try {
-        TaskRequest request = clientRequestQueue.take();
+        SchedulerRequest request = clientRequestQueue.take();
         if (request instanceof AllocateTaskRequest) {
           taskRequestQueue.put((AllocateTaskRequest)request);
+          if (shouldPreempt()) {
+            maybePreempt((AllocateTaskRequest) request);
+          }
         }
         else if (request instanceof DeallocateTaskRequest) {
           deallocateTask((DeallocateTaskRequest)request);
         }
+        else if (request instanceof DeallocateContainerRequest) {
+          preemptTask((DeallocateContainerRequest)request);
+        }
         else {
           LOG.error("Unknown task request message: " + request);
         }
@@ -383,12 +451,29 @@ public class LocalTaskSchedulerService extends TaskScheduler {
       }
     }
 
+    void maybePreempt(AllocateTaskRequest request) {
+      Priority priority = request.priority;
+      for (Map.Entry<Object, AllocatedTask> entry : 
taskAllocations.entrySet()) {
+        AllocatedTask allocatedTask = entry.getValue();
+        Container container = allocatedTask.container;
+        if (priority.compareTo(allocatedTask.container.getPriority()) > 0) {
+          Object task = entry.getKey();
+          ensureVertexDescendants();
+          if 
(vertexDescendants.get(request.vertexIndex).get(allocatedTask.request.vertexIndex))
 {
+            LOG.info("Preempting task/container for task/priority:"  + task + 
"/" + container
+                + " for " + request.task + "/" + priority);
+            
taskSchedulerContext.preemptContainer(allocatedTask.container.getId());
+          }
+        }
+      }
+    }
+
     void allocateTask() {
       try {
         AllocateTaskRequest request = taskRequestQueue.take();
         Container container = 
localContainerFactory.createContainer(request.capability,
             request.priority);
-        taskAllocations.put(request.task, container);
+        taskAllocations.put(request.task, new AllocatedTask(request, 
container));
         taskSchedulerContext.taskAllocated(request.task, request.clientCookie, 
container);
       } catch (InterruptedException e) {
         Thread.currentThread().interrupt();
@@ -396,24 +481,34 @@ public class LocalTaskSchedulerService extends TaskScheduler {
     }
 
     void deallocateTask(DeallocateTaskRequest request) {
-      Container container = taskAllocations.remove(request.task);
-      if (container != null) {
-        taskSchedulerContext.containerBeingReleased(container.getId());
+      AllocatedTask allocatedTask = taskAllocations.remove(request.task);
+      if (allocatedTask != null) {
+        
taskSchedulerContext.containerBeingReleased(allocatedTask.container.getId());
       }
       else {
-        boolean deallocationBeforeAllocation = false;
         Iterator<AllocateTaskRequest> iter = taskRequestQueue.iterator();
         while (iter.hasNext()) {
           TaskRequest taskRequest = iter.next();
           if (taskRequest.task.equals(request.task)) {
             iter.remove();
-            deallocationBeforeAllocation = true;
             LOG.info("Deallocation request before allocation for task:" + 
request.task);
             break;
           }
         }
-        if (!deallocationBeforeAllocation) {
-          throw new TezUncheckedException("Unable to find and remove task " + 
request.task + " from task allocations");
+      }
+    }
+
+    void preemptTask(DeallocateContainerRequest request) {
+      LOG.info("Trying to preempt: " + request.containerId);
+      Iterator<Map.Entry<Object, AllocatedTask>> entries = 
taskAllocations.entrySet().iterator();
+      while (entries.hasNext()) {
+        Map.Entry<Object, AllocatedTask> entry = entries.next();
+        Container container = entry.getValue().container;
+        if (container.getId().equals(request.containerId)) {
+          entries.remove();
+          Object task = entry.getKey();
+          LOG.info("Preempting task/container:" + task + "/" + container);
+          taskSchedulerContext.containerBeingReleased(container.getId());
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/tez/blob/c34e46c7/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java
index 36505c2..d7b516a 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java
@@ -29,13 +29,13 @@ import org.junit.Test;
 
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
-import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.api.records.Priority;
 
 import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.app.rm.LocalTaskSchedulerService.AllocatedTask;
 import 
org.apache.tez.dag.app.rm.LocalTaskSchedulerService.AsyncDelegateRequestHandler;
 import 
org.apache.tez.dag.app.rm.LocalTaskSchedulerService.LocalContainerFactory;
-import org.apache.tez.dag.app.rm.LocalTaskSchedulerService.TaskRequest;
+import org.apache.tez.dag.app.rm.LocalTaskSchedulerService.SchedulerRequest;
 
 public class TestLocalTaskScheduler {
 
@@ -56,8 +56,8 @@ public class TestLocalTaskScheduler {
 
     LocalContainerFactory containerFactory = new 
LocalContainerFactory(appAttemptId, 1000);
 
-    HashMap<Object, Container> taskAllocations = new LinkedHashMap<Object, 
Container>();
-    LinkedBlockingQueue<TaskRequest> clientRequestQueue = new 
LinkedBlockingQueue<>();
+    HashMap<Object, AllocatedTask> taskAllocations = new LinkedHashMap<>();
+    LinkedBlockingQueue<SchedulerRequest> clientRequestQueue = new 
LinkedBlockingQueue<>();
 
     // Object under test
     AsyncDelegateRequestHandler requestHandler =

http://git-wip-us.apache.org/repos/asf/tez/blob/c34e46c7/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java
index c2daf84..70e31f3 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java
@@ -18,20 +18,25 @@
 
 package org.apache.tez.dag.app.rm;
 
+import java.util.BitSet;
 import java.util.HashMap;
 import java.util.concurrent.LinkedBlockingQueue;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
-import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.Priority;
 import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.app.dag.Task;
 import 
org.apache.tez.dag.app.rm.TestLocalTaskSchedulerService.MockLocalTaskSchedulerSerivce.MockAsyncDelegateRequestHandler;
+import org.apache.tez.serviceplugins.api.DagInfo;
 import org.apache.tez.serviceplugins.api.TaskSchedulerContext;
 import org.junit.Assert;
 import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
@@ -138,6 +143,82 @@ public class TestLocalTaskSchedulerService {
     taskSchedulerService.shutdown();
   }
 
+  @Test
+  public void preemptDescendantsOnly() {
+
+    final int MAX_TASKS = 2;
+    TezConfiguration tezConf = new TezConfiguration();
+    tezConf.setInt(TezConfiguration.TEZ_AM_INLINE_TASK_EXECUTION_MAX_TASKS, 
MAX_TASKS);
+
+    ApplicationId appId = ApplicationId.newInstance(2000, 1);
+    ApplicationAttemptId appAttemptId = 
ApplicationAttemptId.newInstance(appId, 1);
+    Long parentTask1 = new Long(1);
+    Long parentTask2 = new Long(2);
+    Long childTask1 = new Long(3);
+    Long grandchildTask1 = new Long(4);
+
+    TaskSchedulerContext
+        mockContext = 
TestTaskSchedulerHelpers.setupMockTaskSchedulerContext("", 0, "", true,
+        appAttemptId, 1000l, null, tezConf);
+    when(mockContext.getVertexIndexForTask(parentTask1)).thenReturn(0);
+    when(mockContext.getVertexIndexForTask(parentTask2)).thenReturn(0);
+    when(mockContext.getVertexIndexForTask(childTask1)).thenReturn(1);
+    when(mockContext.getVertexIndexForTask(grandchildTask1)).thenReturn(2);
+
+    DagInfo mockDagInfo = mock(DagInfo.class);
+    when(mockDagInfo.getTotalVertices()).thenReturn(3);
+    BitSet vertex1Descendants = new BitSet();
+    vertex1Descendants.set(1);
+    vertex1Descendants.set(2);
+    BitSet vertex2Descendants = new BitSet();
+    vertex2Descendants.set(2);
+    BitSet vertex3Descendants = new BitSet();
+    when(mockDagInfo.getVertexDescendants(0)).thenReturn(vertex1Descendants);
+    when(mockDagInfo.getVertexDescendants(1)).thenReturn(vertex2Descendants);
+    when(mockDagInfo.getVertexDescendants(2)).thenReturn(vertex3Descendants);
+    when(mockContext.getCurrentDagInfo()).thenReturn(mockDagInfo);
+
+    Priority priority1 = Priority.newInstance(1);
+    Priority priority2 = Priority.newInstance(2);
+    Priority priority3 = Priority.newInstance(3);
+    Priority priority4 = Priority.newInstance(4);
+    Resource resource = Resource.newInstance(1024, 1);
+
+    MockLocalTaskSchedulerSerivce taskSchedulerService = new 
MockLocalTaskSchedulerSerivce(mockContext);
+
+    // The mock context need to send a deallocate container request to the 
scheduler service
+    Answer<Void> answer = new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocation) {
+        ContainerId containerId = invocation.getArgumentAt(0, 
ContainerId.class);
+        taskSchedulerService.deallocateContainer(containerId);
+        return null;
+      }
+    };
+    
doAnswer(answer).when(mockContext).preemptContainer(any(ContainerId.class));
+
+    taskSchedulerService.initialize();
+    taskSchedulerService.start();
+    taskSchedulerService.startRequestHandlerThread();
+
+    MockAsyncDelegateRequestHandler requestHandler = 
taskSchedulerService.getRequestHandler();
+    taskSchedulerService.allocateTask(parentTask1, resource, null, null, 
priority1, null, null);
+    taskSchedulerService.allocateTask(childTask1, resource, null, null, 
priority3, null, null);
+    taskSchedulerService.allocateTask(grandchildTask1, resource, null, null, 
priority4, null, null);
+    requestHandler.drainRequest(3);
+
+    // We should not preempt if we have not reached max task allocations
+    Assert.assertEquals("Wrong number of allocate tasks", MAX_TASKS, 
requestHandler.allocateCount);
+    Assert.assertTrue("Another allocation should not fit", 
!requestHandler.shouldProcess());
+
+    // Next task allocation should preempt
+    taskSchedulerService.allocateTask(parentTask2, Resource.newInstance(1024, 
1), null, null, priority2, null, null);
+    requestHandler.drainRequest(5);
+
+    // All allocated tasks should have been removed
+    Assert.assertEquals("Wrong number of preempted tasks", 1, 
requestHandler.preemptCount);
+  }
+
   static class MockLocalTaskSchedulerSerivce extends LocalTaskSchedulerService 
{
 
     private MockAsyncDelegateRequestHandler requestHandler;
@@ -173,12 +254,13 @@ public class TestLocalTaskSchedulerService {
 
       public int allocateCount = 0;
       public int deallocateCount = 0;
+      public int preemptCount = 0;
       public int dispatchCount = 0;
 
       MockAsyncDelegateRequestHandler(
-          LinkedBlockingQueue<TaskRequest> taskRequestQueue,
+          LinkedBlockingQueue<SchedulerRequest> taskRequestQueue,
           LocalContainerFactory localContainerFactory,
-          HashMap<Object, Container> taskAllocations,
+          HashMap<Object, AllocatedTask> taskAllocations,
           TaskSchedulerContext appClientDelegate, Configuration conf) {
         super(taskRequestQueue, localContainerFactory, taskAllocations,
             appClientDelegate, conf);
@@ -211,6 +293,12 @@ public class TestLocalTaskSchedulerService {
         super.deallocateTask(request);
         deallocateCount++;
       }
+
+      @Override
+      void preemptTask(DeallocateContainerRequest request) {
+        super.preemptTask(request);
+        preemptCount++;
+      }
     }
   }
 }

Reply via email to