This is an automated email from the ASF dual-hosted git repository.

hashutosh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/tez.git


The following commit(s) were added to refs/heads/master by this push:
     new 2170126  TEZ-4170 : RootInputInitializerManager could make use of ThreadPool from appContext (Attila Magyar via Rajesh Balamohan)
2170126 is described below

commit 21701261ebc02f4dea1752509b18144d70dc534f
Author: Attila Magyar <[email protected]>
AuthorDate: Thu Jul 2 10:41:50 2020 -0700

    TEZ-4170 : RootInputInitializerManager could make use of ThreadPool from appContext (Attila Magyar via Rajesh Balamohan)
    
    Signed-off-by: Ashutosh Chauhan <[email protected]>
---
 .../dag/app/dag/RootInputInitializerManager.java   | 243 +++++++++------------
 .../apache/tez/dag/app/dag/impl/VertexImpl.java    |  25 +--
 .../app/dag/TestRootInputInitializerManager.java   |  26 ++-
 .../tez/dag/app/dag/impl/TestDAGRecovery.java      |   6 +-
 .../tez/dag/app/dag/impl/TestVertexImpl.java       |  97 +++++++-
 5 files changed, 223 insertions(+), 174 deletions(-)

diff --git 
a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/RootInputInitializerManager.java
 
b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/RootInputInitializerManager.java
index 7ff9fa9..5ce0050 100644
--- 
a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/RootInputInitializerManager.java
+++ 
b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/RootInputInitializerManager.java
@@ -18,35 +18,30 @@
 
 package org.apache.tez.dag.app.dag;
 
-import javax.annotation.Nullable;
+import static org.apache.tez.dag.app.dag.VertexState.FAILED;
 
 import java.io.IOException;
 import java.lang.reflect.UndeclaredThrowableException;
 import java.security.PrivilegedExceptionAction;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.Objects;
 
-import org.apache.tez.common.Preconditions;
-import com.google.common.collect.LinkedListMultimap;
-import com.google.common.collect.ListMultimap;
-import com.google.common.collect.Lists;
+import javax.annotation.Nullable;
 
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.apache.commons.lang.exception.ExceptionUtils;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.yarn.event.EventHandler;
-import org.apache.tez.common.GuavaShim;
+import org.apache.tez.common.Preconditions;
 import org.apache.tez.common.ReflectionUtils;
 import org.apache.tez.common.TezUtilsInternal;
 import org.apache.tez.dag.api.InputDescriptor;
@@ -54,38 +49,38 @@ import org.apache.tez.dag.api.InputInitializerDescriptor;
 import org.apache.tez.dag.api.RootInputLeafOutput;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.TezUncheckedException;
-import org.apache.tez.dag.api.event.*;
 import org.apache.tez.dag.api.event.VertexState;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
 import org.apache.tez.dag.api.oldrecords.TaskState;
 import org.apache.tez.dag.app.AppContext;
 import org.apache.tez.dag.app.dag.event.VertexEventRootInputFailed;
 import org.apache.tez.dag.app.dag.event.VertexEventRootInputInitialized;
 import org.apache.tez.dag.app.dag.impl.AMUserCodeException;
-import org.apache.tez.dag.app.dag.impl.TezRootInputInitializerContextImpl;
 import org.apache.tez.dag.app.dag.impl.AMUserCodeException.Source;
+import org.apache.tez.dag.app.dag.impl.TezRootInputInitializerContextImpl;
+import org.apache.tez.dag.app.dag.impl.VertexImpl;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.api.Event;
 import org.apache.tez.runtime.api.InputInitializer;
 import org.apache.tez.runtime.api.InputInitializerContext;
+import org.apache.tez.runtime.api.events.InputInitializerEvent;
+import org.apache.tez.runtime.api.impl.TezEvent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.util.concurrent.FutureCallback;
-import com.google.common.util.concurrent.Futures;
-import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.collect.LinkedListMultimap;
+import com.google.common.collect.ListMultimap;
+import com.google.common.collect.Lists;
 import com.google.common.util.concurrent.ListeningExecutorService;
-import com.google.common.util.concurrent.MoreExecutors;
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
-
-import org.apache.tez.runtime.api.events.InputInitializerEvent;
-import org.apache.tez.runtime.api.impl.TezEvent;
 
 public class RootInputInitializerManager {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(RootInputInitializerManager.class);
 
-  private final ExecutorService rawExecutor;
-  private final ListeningExecutorService executor;
+  @VisibleForTesting
+  protected ListeningExecutorService executor;
   @SuppressWarnings("rawtypes")
   private final EventHandler eventHandler;
   private volatile boolean isStopped = false;
@@ -96,50 +91,106 @@ public class RootInputInitializerManager {
   private final AppContext appContext;
 
   @VisibleForTesting
-  final Map<String, InitializerWrapper> initializerMap = new HashMap<String, 
InitializerWrapper>();
+  final Map<String, InitializerWrapper> initializerMap = new 
ConcurrentHashMap<>();
 
   public RootInputInitializerManager(Vertex vertex, AppContext appContext,
                                      UserGroupInformation dagUgi, 
StateChangeNotifier stateTracker) {
     this.appContext = appContext;
     this.vertex = vertex;
     this.eventHandler = appContext.getEventHandler();
-    this.rawExecutor = Executors.newCachedThreadPool(new ThreadFactoryBuilder()
-        .setDaemon(true).setNameFormat("InputInitializer {" + 
this.vertex.getName() + "} #%d").build());
-    this.executor = MoreExecutors.listeningDecorator(rawExecutor);
+    this.executor = appContext.getExecService();
     this.dagUgi = dagUgi;
     this.entityStateTracker = stateTracker;
   }
-  
-  public void runInputInitializers(List<RootInputLeafOutput<InputDescriptor, 
InputInitializerDescriptor>>
-      inputs) throws TezException {
-    for (RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> 
input : inputs) {
 
-      InputInitializerContext context =
-          new TezRootInputInitializerContextImpl(input, vertex, appContext, 
this);
 
-      InputInitializer initializer;
+  public void runInputInitializers(
+          List<RootInputLeafOutput<InputDescriptor, 
InputInitializerDescriptor>> inputs, List<TezEvent> pendingInitializerEvents) {
+
+    executor.submit(() -> createAndStartInitializing(inputs, 
pendingInitializerEvents));
+  }
+
+  private void 
createAndStartInitializing(List<RootInputLeafOutput<InputDescriptor, 
InputInitializerDescriptor>> inputs, List<TezEvent> pendingInitializerEvents) {
+    String current = null;
+    try {
+      List<InitializerWrapper> result = new ArrayList<>();
+      for (RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> 
each : inputs) {
+        current = each.getName();
+        InitializerWrapper initializer = createInitializerWrapper(each);
+        initializerMap.put(each.getName(), initializer);
+        registerPendingVertex(each, initializer);
+        result.add(initializer);
+      }
+      handleInitializerEvents(pendingInitializerEvents);
+      pendingInitializerEvents.clear();
+      for (InitializerWrapper inputWrapper : result) {
+        executor.submit(() -> runInitializerAndProcessResult(inputWrapper));
+      }
+    } catch (Throwable t) {
+      VertexImpl vertexImpl = (VertexImpl) vertex;
+      String msg = "Fail to create InputInitializerManager, " + 
ExceptionUtils.getStackTrace(t);
+      LOG.info(msg);
+      vertexImpl.finished(FAILED, VertexTerminationCause.INIT_FAILURE, msg);
+      eventHandler.handle(new VertexEventRootInputFailed(vertex.getVertexId(), 
current,
+              new 
AMUserCodeException(AMUserCodeException.Source.InputInitializer, t)));
+
+    }
+  }
+
+  private void runInitializerAndProcessResult(InitializerWrapper initializer) {
+    try {
+      List<Event> result = runInitializer(initializer);
+      LOG.info("Succeeded InputInitializer for Input: " + 
initializer.getInput().getName() +
+                  " on vertex " + initializer.getVertexLogIdentifier());
+      eventHandler.handle(new 
VertexEventRootInputInitialized(vertex.getVertexId(),
+          initializer.getInput().getName(), result));
+    } catch (Throwable t) {
+        if (t instanceof UndeclaredThrowableException) {
+          t = t.getCause();
+        }
+        LOG.info("Failed InputInitializer for Input: " + 
initializer.getInput().getName() +
+                    " on vertex " + initializer.getVertexLogIdentifier());
+        eventHandler.handle(new 
VertexEventRootInputFailed(vertex.getVertexId(), 
initializer.getInput().getName(),
+                    new AMUserCodeException(Source.InputInitializer,t)));
+    } finally {
+      initializer.setComplete();
+    }
+  }
+
+  private List<Event> runInitializer(InitializerWrapper initializer) throws 
IOException, InterruptedException {
+    return dagUgi.doAs((PrivilegedExceptionAction<List<Event>>) () -> {
+      LOG.info(
+              "Starting InputInitializer for Input: " + 
initializer.getInput().getName() +
+                      " on vertex " + initializer.getVertexLogIdentifier());
       try {
-        TezUtilsInternal.setHadoopCallerContext(appContext.getHadoopShim(), 
vertex.getVertexId());
-        initializer = createInitializer(input, context);
+        TezUtilsInternal.setHadoopCallerContext(appContext.getHadoopShim(),
+                initializer.vertexId);
+        return initializer.getInitializer().initialize();
       } finally {
         appContext.getHadoopShim().clearHadoopCallerContext();
       }
+    });
+  }
 
-      InitializerWrapper initializerWrapper =
-          new InitializerWrapper(input, initializer, context, vertex, 
entityStateTracker, appContext);
+  private InitializerWrapper 
createInitializerWrapper(RootInputLeafOutput<InputDescriptor, 
InputInitializerDescriptor> input) throws TezException {
+    InputInitializerContext context =
+            new TezRootInputInitializerContextImpl(input, vertex, appContext, 
this);
+    try {
+      TezUtilsInternal.setHadoopCallerContext(appContext.getHadoopShim(), 
vertex.getVertexId());
+      InputInitializer initializer = createInitializer(input, context);
+      return new InitializerWrapper(input, initializer, context, vertex, 
entityStateTracker, appContext);
+    } finally {
+      appContext.getHadoopShim().clearHadoopCallerContext();
+    }
+  }
 
-      // Register pending vertex update registrations
-      List<VertexUpdateRegistrationHolder> vertexUpdateRegistrations = 
pendingVertexRegistrations.removeAll(input.getName());
-      if (vertexUpdateRegistrations != null) {
-        for (VertexUpdateRegistrationHolder h : vertexUpdateRegistrations) {
-          initializerWrapper.registerForVertexStateUpdates(h.vertexName, 
h.stateSet);
-        }
+  private void registerPendingVertex(RootInputLeafOutput<InputDescriptor, 
InputInitializerDescriptor> input, InitializerWrapper initializerWrapper) {
+    // Register pending vertex update registrations
+    List<VertexUpdateRegistrationHolder> vertexUpdateRegistrations = 
pendingVertexRegistrations.removeAll(input.getName());
+    if (vertexUpdateRegistrations != null) {
+      for (VertexUpdateRegistrationHolder h : vertexUpdateRegistrations) {
+        initializerWrapper.registerForVertexStateUpdates(h.vertexName, 
h.stateSet);
       }
-
-      initializerMap.put(input.getName(), initializerWrapper);
-      ListenableFuture<List<Event>> future = executor
-          .submit(new InputInitializerCallable(initializerWrapper, dagUgi, 
appContext));
-      Futures.addCallback(future, 
createInputInitializerCallback(initializerWrapper), GuavaShim.directExecutor());
     }
   }
 
@@ -233,103 +284,13 @@ public class RootInputInitializerManager {
   }
 
   @VisibleForTesting
-  protected InputInitializerCallback 
createInputInitializerCallback(InitializerWrapper initializer) {
-    return new InputInitializerCallback(initializer, eventHandler, 
vertex.getVertexId());
-  }
-
-  @VisibleForTesting
   @InterfaceAudience.Private
   public InitializerWrapper getInitializerWrapper(String inputName) {
     return initializerMap.get(inputName);
   }
 
   public void shutdown() {
-    if (executor != null && !isStopped) {
-      // Don't really care about what is running if an error occurs. If no 
error
-      // occurs, all execution is complete.
-      executor.shutdownNow();
-      isStopped = true;
-    }
-  }
-
-  private static class InputInitializerCallable implements
-      Callable<List<Event>> {
-
-    private final InitializerWrapper initializerWrapper;
-    private final UserGroupInformation ugi;
-    private final AppContext appContext;
-
-    public InputInitializerCallable(InitializerWrapper initializer, 
UserGroupInformation ugi,
-                                    AppContext appContext) {
-      this.initializerWrapper = initializer;
-      this.ugi = ugi;
-      this.appContext = appContext;
-    }
-
-    @Override
-    public List<Event> call() throws Exception {
-      List<Event> events = ugi.doAs(new 
PrivilegedExceptionAction<List<Event>>() {
-        @Override
-        public List<Event> run() throws Exception {
-          LOG.info(
-              "Starting InputInitializer for Input: " + 
initializerWrapper.getInput().getName() +
-                  " on vertex " + initializerWrapper.getVertexLogIdentifier());
-          try {
-            TezUtilsInternal.setHadoopCallerContext(appContext.getHadoopShim(),
-                initializerWrapper.vertexId);
-            return initializerWrapper.getInitializer().initialize();
-          } finally {
-            appContext.getHadoopShim().clearHadoopCallerContext();
-          }
-        }
-      });
-      return events;
-    }
-  }
-
-  @SuppressWarnings("rawtypes")
-  @VisibleForTesting
-  private static class InputInitializerCallback implements
-      FutureCallback<List<Event>> {
-
-    private final InitializerWrapper initializer;
-    private final EventHandler eventHandler;
-    private final TezVertexID vertexID;
-
-    public InputInitializerCallback(InitializerWrapper initializer,
-        EventHandler eventHandler, TezVertexID vertexID) {
-      this.initializer = initializer;
-      this.eventHandler = eventHandler;
-      this.vertexID = vertexID;
-    }
-
-    @SuppressWarnings("unchecked")
-    @Override
-    public void onSuccess(List<Event> result) {
-      initializer.setComplete();
-      LOG.info(
-          "Succeeded InputInitializer for Input: " + 
initializer.getInput().getName() +
-              " on vertex " + initializer.getVertexLogIdentifier());
-      eventHandler.handle(new VertexEventRootInputInitialized(vertexID,
-          initializer.getInput().getName(), result));
-    }
-
-    @SuppressWarnings("unchecked")
-    @Override
-    public void onFailure(Throwable t) {
-      // catch real root cause of failure, it would throw 
UndeclaredThrowableException
-      // if using UGI.doAs
-      if (t instanceof UndeclaredThrowableException) {
-        t = t.getCause();
-      }
-      initializer.setComplete();
-      LOG.info(
-          "Failed InputInitializer for Input: " + 
initializer.getInput().getName() +
-              " on vertex " + initializer.getVertexLogIdentifier());
-      eventHandler
-          .handle(new VertexEventRootInputFailed(vertexID, 
initializer.getInput().getName(),
-              new AMUserCodeException(Source.InputInitializer,t)));
-    }
+    isStopped = true;
   }
 
   @VisibleForTesting
diff --git 
a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java 
b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index 85ae38d..db0cd46 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -2416,7 +2416,7 @@ public class VertexImpl implements 
org.apache.tez.dag.app.dag.Vertex, EventHandl
     }
   }
 
-  VertexState finished(VertexState finalState,
+  public VertexState finished(VertexState finalState,
       VertexTerminationCause termCause, String diag) {
     if (finishTime == 0) setFinishTime();
     if (termCause != null) {
@@ -3073,13 +3073,7 @@ public class VertexImpl implements 
org.apache.tez.dag.app.dag.Vertex, EventHandl
         if (vertex.inputsWithInitializers != null) {
           if (vertex.recoveryData == null || 
!vertex.recoveryData.shouldSkipInit()) {
             LOG.info("Vertex will initialize from input initializer. " + 
vertex.logIdentifier);
-            try {
-              vertex.setupInputInitializerManager();
-            } catch (TezException e) {
-              String msg = "Fail to create InputInitializerManager, " + 
ExceptionUtils.getStackTrace(e);
-              LOG.info(msg);
-              return vertex.finished(VertexState.FAILED, 
VertexTerminationCause.INIT_FAILURE, msg);
-            }
+            vertex.setupInputInitializerManager();
           }
           return VertexState.INITIALIZING;
         } else {
@@ -3112,13 +3106,7 @@ public class VertexImpl implements 
org.apache.tez.dag.app.dag.Vertex, EventHandl
         if (vertex.inputsWithInitializers != null &&
             (vertex.recoveryData == null || 
!vertex.recoveryData.shouldSkipInit())) {
           LOG.info("Vertex will initialize from input initializer. " + 
vertex.logIdentifier);
-          try {
-            vertex.setupInputInitializerManager();
-          } catch (TezException e) {
-            String msg = "Fail to create InputInitializerManager, " + 
ExceptionUtils.getStackTrace(e);
-            LOG.error(msg);
-            return vertex.finished(VertexState.FAILED, 
VertexTerminationCause.INIT_FAILURE, msg);
-          }
+          vertex.setupInputInitializerManager();
           return VertexState.INITIALIZING;
         }
         if (!vertex.uninitializedEdges.isEmpty()) {
@@ -4255,7 +4243,7 @@ public class VertexImpl implements 
org.apache.tez.dag.app.dag.Vertex, EventHandl
     }
   }
 
-  private void setupInputInitializerManager() throws TezException {
+  private void setupInputInitializerManager() {
     rootInputInitializerManager = createRootInputInitializerManager(
         getDAG().getName(), getName(), getVertexId(),
         eventHandler, getTotalTasks(),
@@ -4270,10 +4258,7 @@ public class VertexImpl implements 
org.apache.tez.dag.app.dag.Vertex, EventHandl
     LOG.info("Starting " + inputsWithInitializers.size() + " inputInitializers 
for vertex " +
         logIdentifier);
     initWaitsForRootInitializers = true;
-    rootInputInitializerManager.runInputInitializers(inputList);
-    // Send pending rootInputInitializerEvents
-    
rootInputInitializerManager.handleInitializerEvents(pendingInitializerEvents);
-    pendingInitializerEvents.clear();
+    rootInputInitializerManager.runInputInitializers(inputList, 
pendingInitializerEvents);
   }
 
   private static class VertexStateChangedCallback
diff --git 
a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/TestRootInputInitializerManager.java
 
b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/TestRootInputInitializerManager.java
index b79b4af..01cc37f 100644
--- 
a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/TestRootInputInitializerManager.java
+++ 
b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/TestRootInputInitializerManager.java
@@ -28,8 +28,14 @@ import static org.mockito.Mockito.when;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 import com.google.common.collect.Lists;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.event.EventHandler;
@@ -50,10 +56,27 @@ import org.apache.tez.runtime.api.InputInitializerContext;
 import org.apache.tez.runtime.api.events.InputInitializerEvent;
 import org.apache.tez.runtime.api.impl.EventMetaData;
 import org.apache.tez.runtime.api.impl.TezEvent;
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 
 public class TestRootInputInitializerManager {
+  ListeningExecutorService execService;
+
+  @Before
+  public void setUp() throws Exception {
+    ExecutorService rawExecutor = Executors.newCachedThreadPool(new 
ThreadFactoryBuilder()
+            .setDaemon(true).setNameFormat("Test App Shared Pool - " + 
"#%d").build());
+    execService = MoreExecutors.listeningDecorator(rawExecutor);
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    if (execService != null) {
+      execService.shutdownNow();
+    }
+  }
 
   // Simple testing. No events if task doesn't succeed.
   // Also exercises path where two attempts are reported as successful via the 
stateChangeNotifier.
@@ -214,6 +237,7 @@ public class TestRootInputInitializerManager {
     AppContext appContext = mock(AppContext.class);
     doReturn(new DefaultHadoopShim()).when(appContext).getHadoopShim();
     doReturn(mock(EventHandler.class)).when(appContext).getEventHandler();
+    when(appContext.getExecService()).thenReturn(execService);
     UserGroupInformation dagUgi = 
UserGroupInformation.createRemoteUser("fakeuser");
     StateChangeNotifier stateChangeNotifier = mock(StateChangeNotifier.class);
     RootInputInitializerManager rootInputInitializerManager = new 
RootInputInitializerManager(vertex, appContext, dagUgi, stateChangeNotifier);
@@ -222,7 +246,7 @@ public class TestRootInputInitializerManager {
     InputInitializerDescriptor iid = 
InputInitializerDescriptor.create(InputInitializerForUgiTest.class.getName());
     RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> rootInput 
=
         new RootInputLeafOutput<>("InputName", id, iid);
-    
rootInputInitializerManager.runInputInitializers(Collections.singletonList(rootInput));
+    
rootInputInitializerManager.runInputInitializers(Collections.singletonList(rootInput),
 Collections.emptyList());
 
     InputInitializerForUgiTest.awaitInitialize();
 
diff --git 
a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGRecovery.java 
b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGRecovery.java
index 95ea8a0..9636329 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGRecovery.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGRecovery.java
@@ -330,8 +330,10 @@ public class TestDAGRecovery {
     Mockito.doAnswer(new Answer() {
       public ListenableFuture<Void> answer(InvocationOnMock invocation) {
         Object[] args = invocation.getArguments();
-        CallableEvent e = (CallableEvent) args[0];
-        dispatcher.getEventHandler().handle(e);
+        if (args[0] instanceof CallableEvent) {
+          CallableEvent e = (CallableEvent) args[0];
+          dispatcher.getEventHandler().handle(e);
+        }
         return mockFuture;
       }
     }).when(execService).submit((Callable<Void>) any());
diff --git 
a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java 
b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
index 5722406..5ae9556 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
@@ -43,11 +43,15 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.locks.Condition;
 import java.util.concurrent.locks.ReentrantLock;
 
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import com.google.protobuf.ByteString;
 
 import org.apache.commons.lang.StringUtils;
@@ -2421,10 +2425,10 @@ public class TestVertexImpl {
               dagConf);
         }
       } else {
-        v = new VertexImpl(vertexId, vPlan, vPlan.getName(), conf,
-            dispatcher.getEventHandler(), taskCommunicatorManagerInterface,
-            clock, thh, true, appContext, locationHint, vertexGroups, 
taskSpecificLaunchCmdOption,
-            updateTracker, dagConf);
+        v = new VertexImplWithRunningInputInitializerWithExecutor(vertexId, 
vPlan, vPlan.getName(), conf,
+                dispatcher.getEventHandler(), taskCommunicatorManagerInterface,
+                clock, thh, appContext, locationHint, dispatcher, 
customInitializer, updateTracker,
+                dagConf, vertexGroups);
       }
       vertices.put(vName, v);
       vertexIdMap.put(vertexId, v);
@@ -2528,8 +2532,10 @@ public class TestVertexImpl {
     Mockito.doAnswer(new Answer() {
       public ListenableFuture<Void> answer(InvocationOnMock invocation) {
           Object[] args = invocation.getArguments();
-          CallableEvent e = (CallableEvent) args[0];
-          dispatcher.getEventHandler().handle(e);
+          if (args[0] instanceof CallableEvent) {
+            CallableEvent e = (CallableEvent) args[0];
+            dispatcher.getEventHandler().handle(e);
+          }
           return mockFuture;
       }})
     .when(execService).submit((Callable<Void>) any());
@@ -2760,12 +2766,13 @@ public class TestVertexImpl {
   }
 
   @Test(timeout=5000)
-  public void testNonExistInputInitializer() throws TezException {
+  public void testNonExistInputInitializer() throws Exception {
     setupPreDagCreation();
     dagPlan = createDAGPlanWithNonExistInputInitializer();
     setupPostDagCreation();
     VertexImpl v1 = vertices.get("vertex1");
     v1.handle(new VertexEvent(v1.getVertexId(), VertexEventType.V_INIT));
+    while (v1.getTerminationCause() == null) Thread.sleep(10);
     Assert.assertEquals(VertexState.FAILED, v1.getState());
     Assert.assertEquals(VertexTerminationCause.INIT_FAILURE, 
v1.getTerminationCause());
     Assert.assertTrue(StringUtils.join(v1.getDiagnostics(),"")
@@ -5843,6 +5850,43 @@ public class TestVertexImpl {
   }
 
   @SuppressWarnings("rawtypes")
+  private static class VertexImplWithRunningInputInitializerWithExecutor 
extends VertexImpl {
+    private RootInputInitializerManagerWithExecutor 
rootInputInitializerManager;
+
+    public VertexImplWithRunningInputInitializerWithExecutor(TezVertexID 
vertexId,
+                                                             VertexPlan 
vertexPlan, String vertexName,
+                                                             Configuration 
conf,
+                                                             EventHandler 
eventHandler,
+                                                             
TaskCommunicatorManagerInterface taskCommunicatorManagerInterface,
+                                                             Clock clock, 
TaskHeartbeatHandler thh,
+                                                             AppContext 
appContext,
+                                                             
VertexLocationHint vertexLocationHint,
+                                                             DrainDispatcher 
dispatcher,
+                                                             InputInitializer 
presetInitializer,
+                                                             
StateChangeNotifier updateTracker,
+                                                             Configuration 
dagConf,
+                                                             Map<String, 
VertexGroupInfo> vertexGroups) {
+      super(vertexId, vertexPlan, vertexName, conf, eventHandler,
+              taskCommunicatorManagerInterface, clock, thh, true,
+              appContext, vertexLocationHint, vertexGroups, 
taskSpecificLaunchCmdOption,
+              updateTracker, dagConf);
+    }
+
+    @Override
+    protected RootInputInitializerManager createRootInputInitializerManager(
+            String dagName, String vertexName, TezVertexID vertexID,
+            EventHandler eventHandler, int numTasks, int numNodes,
+            Resource taskResource, Resource totalResource) {
+      try {
+        rootInputInitializerManager = new 
RootInputInitializerManagerWithExecutor(this, this.getAppContext(), 
stateChangeNotifier);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      return rootInputInitializerManager;
+    }
+  }
+
+  @SuppressWarnings("rawtypes")
   private static class VertexImplWithControlledInitializerManager extends 
VertexImpl {
     
     private final DrainDispatcher dispatcher;
@@ -5898,9 +5942,11 @@ public class TestVertexImpl {
         IOException {
       super(vertex, appContext, UserGroupInformation.getCurrentUser(), 
tracker);
       this.presetInitializer = presetInitializer;
+      ExecutorService rawExecutor = Executors.newCachedThreadPool(new 
ThreadFactoryBuilder()
+              .setDaemon(true).setNameFormat("Test App Shared Pool - " + 
"#%d").build());
+      this.executor = MoreExecutors.listeningDecorator(rawExecutor);
     }
 
-
     @Override
     protected InputInitializer createInitializer(
         RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> input,
@@ -5910,6 +5956,31 @@ public class TestVertexImpl {
       }
       return presetInitializer;
     }
+
+    @Override
+    public void shutdown() {
+      super.shutdown();
+      if (executor != null) {
+        executor.shutdown();
+      }
+    }
+  }
+
+  private static class RootInputInitializerManagerWithExecutor extends 
RootInputInitializerManager {
+    public RootInputInitializerManagerWithExecutor(Vertex vertex, AppContext 
appContext, StateChangeNotifier tracker) throws IOException {
+      super(vertex, appContext, UserGroupInformation.getCurrentUser(), 
tracker);
+      ExecutorService rawExecutor = Executors.newCachedThreadPool(new 
ThreadFactoryBuilder()
+              .setDaemon(true).setNameFormat("Test App Shared Pool - " + 
"#%d").build());
+      this.executor = MoreExecutors.listeningDecorator(rawExecutor);
+    }
+
+    @Override
+    public void shutdown() {
+      super.shutdown();
+      if (executor != null) {
+        executor.shutdown();
+      }
+    }
   }
 
   @SuppressWarnings({"rawtypes", "unchecked"})
@@ -5931,11 +6002,14 @@ public class TestVertexImpl {
       this.eventHandler = eventHandler;
       this.dispatcher = dispatcher;
       this.vertexID = vertex.getVertexId();
+      ExecutorService rawExecutor = Executors.newCachedThreadPool(new 
ThreadFactoryBuilder()
+              .setDaemon(true).setNameFormat("Test App Shared Pool - " + 
"#%d").build());
+      this.executor = MoreExecutors.listeningDecorator(rawExecutor);
     }
 
     @Override
     public void runInputInitializers(
-        List<RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor>> 
inputs) {
+            List<RootInputLeafOutput<InputDescriptor, 
InputInitializerDescriptor>> inputs, List<TezEvent> pendingInitializerEvents) {
       this.inputs = inputs;
     }
 
@@ -5961,10 +6035,13 @@ public class TestVertexImpl {
     @Override
     public void shutdown() {
       hasShutDown = true;
+      if (executor != null) {
+        executor.shutdown();
+      }
     }
 
     public void failInputInitialization() throws TezException {
-      super.runInputInitializers(inputs);
+      super.runInputInitializers(inputs, Collections.emptyList());
       eventHandler.handle(new VertexEventRootInputFailed(vertexID, inputs
           .get(0).getName(),
           new AMUserCodeException(Source.InputInitializer,

Reply via email to