TEZ-2126. Add unit tests for verifying multiple schedulers, launchers, communicators. (sseth)
Project: http://git-wip-us.apache.org/repos/asf/tez/repo Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/7ef9dda7 Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/7ef9dda7 Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/7ef9dda7 Branch: refs/heads/TEZ-2003 Commit: 7ef9dda732f124a234cbe8a665deb08c065f628f Parents: 60eb5f5 Author: Siddharth Seth <[email protected]> Authored: Thu Aug 6 01:04:31 2015 -0700 Committer: Siddharth Seth <[email protected]> Committed: Fri Aug 21 18:14:41 2015 -0700 ---------------------------------------------------------------------- TEZ-2003-CHANGES.txt | 1 + .../tez/dag/api/NamedEntityDescriptor.java | 7 + .../org/apache/tez/dag/app/DAGAppMaster.java | 163 ++++---- .../dag/app/TaskAttemptListenerImpTezDag.java | 94 ++--- .../apache/tez/dag/app/dag/impl/VertexImpl.java | 9 +- .../app/launcher/ContainerLauncherRouter.java | 126 ++++--- .../dag/app/rm/TaskSchedulerEventHandler.java | 137 +++---- .../apache/tez/dag/app/MockDAGAppMaster.java | 3 +- .../apache/tez/dag/app/TestDAGAppMaster.java | 300 +++++++++++++++ .../app/TestTaskAttemptListenerImplTezDag.java | 44 ++- .../app/TestTaskAttemptListenerImplTezDag2.java | 6 +- .../dag/app/TestTaskCommunicatorManager.java | 369 +++++++++++++++++++ .../tez/dag/app/dag/impl/TestVertexImpl2.java | 279 ++++++++++++-- .../launcher/TestContainerLauncherRouter.java | 361 ++++++++++++++++++ .../app/rm/TestTaskSchedulerEventHandler.java | 330 ++++++++++++++++- .../dag/app/rm/TestTaskSchedulerHelpers.java | 4 +- 16 files changed, 1907 insertions(+), 326 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/TEZ-2003-CHANGES.txt ---------------------------------------------------------------------- diff --git a/TEZ-2003-CHANGES.txt b/TEZ-2003-CHANGES.txt index c7a3dcc..f921739 100644 --- a/TEZ-2003-CHANGES.txt +++ b/TEZ-2003-CHANGES.txt @@ -42,5 +42,6 @@ ALL CHANGES: TEZ-2441. 
Add tests for TezTaskRunner2. TEZ-2657. Add tests for client side changes - specifying plugins, etc. TEZ-2626. Fix log lines with DEBUG in messages, consolidate TEZ-2003 TODOs. + TEZ-2126. Add unit tests for verifying multiple schedulers, launchers, communicators. INCOMPATIBLE CHANGES: http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java ---------------------------------------------------------------------- diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java b/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java index 723d43f..17c8c6c 100644 --- a/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java +++ b/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java @@ -35,4 +35,11 @@ public class NamedEntityDescriptor<T extends NamedEntityDescriptor<T>> extends E super.setUserPayload(userPayload); return (T) this; } + + @Override + public String toString() { + boolean hasPayload = + getUserPayload() == null ? false : getUserPayload().getPayload() == null ? 
false : true; + return "EntityName=" + entityName + ", ClassName=" + getClassName() + ", hasPayload=" + hasPayload; + } } http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java index 53e15e8..f88c1de 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java @@ -59,6 +59,7 @@ import java.util.regex.Pattern; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; +import com.google.common.collect.Lists; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.GnuParser; import org.apache.commons.cli.Options; @@ -389,42 +390,16 @@ public class DAGAppMaster extends AbstractService { this.isLocal = conf.getBoolean(TezConfiguration.TEZ_LOCAL_MODE, TezConfiguration.TEZ_LOCAL_MODE_DEFAULT); - List<NamedEntityDescriptor> taskSchedulerDescriptors; - List<NamedEntityDescriptor> containerLauncherDescriptors; - List<NamedEntityDescriptor> taskCommunicatorDescriptors; - boolean tezYarnEnabled = true; - boolean uberEnabled = false; - - if (!isLocal) { - if (amPluginDescriptorProto == null) { - tezYarnEnabled = true; - uberEnabled = false; - } else { - tezYarnEnabled = amPluginDescriptorProto.getContainersEnabled(); - uberEnabled = amPluginDescriptorProto.getUberEnabled(); - } - } else { - tezYarnEnabled = false; - uberEnabled = true; - } + UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(amConf); - taskSchedulerDescriptors = parsePlugin(taskSchedulers, - (amPluginDescriptorProto == null || amPluginDescriptorProto.getTaskSchedulersCount() == 0 ? 
- null : - amPluginDescriptorProto.getTaskSchedulersList()), - tezYarnEnabled, uberEnabled); + List<NamedEntityDescriptor> taskSchedulerDescriptors = Lists.newLinkedList(); + List<NamedEntityDescriptor> containerLauncherDescriptors = Lists.newLinkedList(); + List<NamedEntityDescriptor> taskCommunicatorDescriptors = Lists.newLinkedList(); - containerLauncherDescriptors = parsePlugin(containerLaunchers, - (amPluginDescriptorProto == null || - amPluginDescriptorProto.getContainerLaunchersCount() == 0 ? null : - amPluginDescriptorProto.getContainerLaunchersList()), - tezYarnEnabled, uberEnabled); + parseAllPlugins(taskSchedulerDescriptors, taskSchedulers, containerLauncherDescriptors, + containerLaunchers, taskCommunicatorDescriptors, taskCommunicators, amPluginDescriptorProto, + isLocal, defaultPayload); - taskCommunicatorDescriptors = parsePlugin(taskCommunicators, - (amPluginDescriptorProto == null || - amPluginDescriptorProto.getTaskCommunicatorsCount() == 0 ? null : - amPluginDescriptorProto.getTaskCommunicatorsList()), - tezYarnEnabled, uberEnabled); LOG.info(buildPluginComponentLog(taskSchedulerDescriptors, taskSchedulers, "TaskSchedulers")); @@ -494,12 +469,11 @@ public class DAGAppMaster extends AbstractService { jobTokenSecretManager.addTokenForJob( appAttemptID.getApplicationId().toString(), sessionToken); - UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(amConf); + //service to handle requests to TaskUmbilicalProtocol taskAttemptListener = createTaskAttemptListener(context, - taskHeartbeatHandler, containerHeartbeatHandler, taskCommunicatorDescriptors, - defaultPayload, isLocal); + taskHeartbeatHandler, containerHeartbeatHandler, taskCommunicatorDescriptors); addIfService(taskAttemptListener, true); containerSignatureMatcher = createContainerSignatureMatcher(); @@ -549,7 +523,7 @@ public class DAGAppMaster extends AbstractService { this.taskSchedulerEventHandler = new TaskSchedulerEventHandler(context, clientRpcServer, 
dispatcher.getEventHandler(), containerSignatureMatcher, webUIService, - taskSchedulerDescriptors, defaultPayload, isLocal); + taskSchedulerDescriptors, isLocal); addIfService(taskSchedulerEventHandler, true); if (enableWebUIService()) { @@ -567,7 +541,7 @@ public class DAGAppMaster extends AbstractService { taskSchedulerEventHandler); addIfServiceDependency(taskSchedulerEventHandler, clientRpcServer); - this.containerLauncherRouter = createContainerLauncherRouter(defaultPayload, containerLauncherDescriptors, isLocal); + this.containerLauncherRouter = createContainerLauncherRouter(containerLauncherDescriptors, isLocal); addIfService(containerLauncherRouter, true); dispatcher.register(NMCommunicatorEventType.class, containerLauncherRouter); @@ -1077,12 +1051,9 @@ public class DAGAppMaster extends AbstractService { protected TaskAttemptListener createTaskAttemptListener(AppContext context, TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh, - List<NamedEntityDescriptor> entityDescriptors, - UserPayload defaultUserPayload, - boolean isLocal) { + List<NamedEntityDescriptor> entityDescriptors) { TaskAttemptListener lis = - new TaskAttemptListenerImpTezDag(context, thh, chh, - entityDescriptors, defaultUserPayload, isLocal); + new TaskAttemptListenerImpTezDag(context, thh, chh, entityDescriptors); return lis; } @@ -1103,11 +1074,10 @@ public class DAGAppMaster extends AbstractService { return chh; } - protected ContainerLauncherRouter createContainerLauncherRouter(UserPayload defaultPayload, - List<NamedEntityDescriptor> containerLauncherDescriptors, + protected ContainerLauncherRouter createContainerLauncherRouter(List<NamedEntityDescriptor> containerLauncherDescriptors, boolean isLocal) throws UnknownHostException { - return new ContainerLauncherRouter(defaultPayload, context, taskAttemptListener, workingDirectory, + return new ContainerLauncherRouter(context, taskAttemptListener, workingDirectory, containerLauncherDescriptors, isLocal); } @@ -2407,41 +2377,106 @@ 
public class DAGAppMaster extends AbstractService { TezConfiguration.TEZ_AM_WEBSERVICE_ENABLE_DEFAULT); } - private static List<NamedEntityDescriptor> parsePlugin( - BiMap<String, Integer> pluginMap, List<TezNamedEntityDescriptorProto> namedEntityDescriptorProtos, - boolean tezYarnEnabled, boolean uberEnabled) { - int index = 0; + @VisibleForTesting + static void parseAllPlugins( + List<NamedEntityDescriptor> taskSchedulerDescriptors, BiMap<String, Integer> taskSchedulerPluginMap, + List<NamedEntityDescriptor> containerLauncherDescriptors, BiMap<String, Integer> containerLauncherPluginMap, + List<NamedEntityDescriptor> taskCommDescriptors, BiMap<String, Integer> taskCommPluginMap, + AMPluginDescriptorProto amPluginDescriptorProto, boolean isLocal, UserPayload defaultPayload) { + + boolean tezYarnEnabled; + boolean uberEnabled; + if (!isLocal) { + if (amPluginDescriptorProto == null) { + tezYarnEnabled = true; + uberEnabled = false; + } else { + tezYarnEnabled = amPluginDescriptorProto.getContainersEnabled(); + uberEnabled = amPluginDescriptorProto.getUberEnabled(); + } + } else { + tezYarnEnabled = false; + uberEnabled = true; + } + + parsePlugin(taskSchedulerDescriptors, taskSchedulerPluginMap, + (amPluginDescriptorProto == null || amPluginDescriptorProto.getTaskSchedulersCount() == 0 ? + null : + amPluginDescriptorProto.getTaskSchedulersList()), + tezYarnEnabled, uberEnabled, defaultPayload); + processSchedulerDescriptors(taskSchedulerDescriptors, isLocal, defaultPayload, taskSchedulerPluginMap); - List<NamedEntityDescriptor> resultList = new LinkedList<>(); + parsePlugin(containerLauncherDescriptors, containerLauncherPluginMap, + (amPluginDescriptorProto == null || + amPluginDescriptorProto.getContainerLaunchersCount() == 0 ? 
null : + amPluginDescriptorProto.getContainerLaunchersList()), + tezYarnEnabled, uberEnabled, defaultPayload); + + parsePlugin(taskCommDescriptors, taskCommPluginMap, + (amPluginDescriptorProto == null || + amPluginDescriptorProto.getTaskCommunicatorsCount() == 0 ? null : + amPluginDescriptorProto.getTaskCommunicatorsList()), + tezYarnEnabled, uberEnabled, defaultPayload); + } + + + @VisibleForTesting + static void parsePlugin(List<NamedEntityDescriptor> resultList, + BiMap<String, Integer> pluginMap, List<TezNamedEntityDescriptorProto> namedEntityDescriptorProtos, + boolean tezYarnEnabled, boolean uberEnabled, UserPayload defaultPayload) { if (tezYarnEnabled) { // Default classnames will be populated by individual components NamedEntityDescriptor r = new NamedEntityDescriptor( - TezConstants.getTezYarnServicePluginName(), null); - resultList.add(r); - pluginMap.put(TezConstants.getTezYarnServicePluginName(), index); - index++; + TezConstants.getTezYarnServicePluginName(), null).setUserPayload(defaultPayload); + addDescriptor(resultList, pluginMap, r); } if (uberEnabled) { // Default classnames will be populated by individual components NamedEntityDescriptor r = new NamedEntityDescriptor( - TezConstants.getTezUberServicePluginName(), null); - resultList.add(r); - pluginMap.put(TezConstants.getTezUberServicePluginName(), index); - index++; + TezConstants.getTezUberServicePluginName(), null).setUserPayload(defaultPayload); + addDescriptor(resultList, pluginMap, r); } if (namedEntityDescriptorProtos != null) { for (TezNamedEntityDescriptorProto namedEntityDescriptorProto : namedEntityDescriptorProtos) { - resultList.add(DagTypeConverters - .convertNamedDescriptorFromProto(namedEntityDescriptorProto)); - pluginMap.put(resultList.get(index).getEntityName(), index); - index++; + NamedEntityDescriptor namedEntityDescriptor = DagTypeConverters + .convertNamedDescriptorFromProto(namedEntityDescriptorProto); + addDescriptor(resultList, pluginMap, namedEntityDescriptor); + } 
+ } + } + + @VisibleForTesting + static void addDescriptor(List<NamedEntityDescriptor> list, BiMap<String, Integer> pluginMap, + NamedEntityDescriptor namedEntityDescriptor) { + list.add(namedEntityDescriptor); + pluginMap.put(list.get(list.size() - 1).getEntityName(), list.size() - 1); + } + + @VisibleForTesting + static void processSchedulerDescriptors(List<NamedEntityDescriptor> descriptors, boolean isLocal, + UserPayload defaultPayload, + BiMap<String, Integer> schedulerPluginMap) { + if (isLocal) { + Preconditions.checkState(descriptors.size() == 1 && + descriptors.get(0).getEntityName().equals(TezConstants.getTezUberServicePluginName())); + } else { + boolean foundYarn = false; + for (int i = 0; i < descriptors.size(); i++) { + if (descriptors.get(i).getEntityName().equals(TezConstants.getTezYarnServicePluginName())) { + foundYarn = true; + } + } + if (!foundYarn) { + NamedEntityDescriptor yarnDescriptor = + new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) + .setUserPayload(defaultPayload); + addDescriptor(descriptors, schedulerPluginMap, yarnDescriptor); } } - return resultList; } String buildPluginComponentLog(List<NamedEntityDescriptor> namedEntityDescriptors, BiMap<String, Integer> map, http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java index 941e583..7b97738 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java @@ -27,7 +27,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import com.google.common.annotations.VisibleForTesting; -import 
com.google.common.collect.Lists; +import com.google.common.base.Preconditions; import org.apache.commons.collections4.ListUtils; import org.apache.tez.dag.api.NamedEntityDescriptor; import org.apache.tez.dag.api.TezConstants; @@ -102,35 +102,19 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements public TaskAttemptListenerImpTezDag(AppContext context, TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh, - List<NamedEntityDescriptor> taskCommunicatorDescriptors, - UserPayload defaultUserPayload, - boolean isPureLocalMode) { + List<NamedEntityDescriptor> taskCommunicatorDescriptors) { super(TaskAttemptListenerImpTezDag.class.getName()); this.context = context; this.taskHeartbeatHandler = thh; this.containerHeartbeatHandler = chh; - if (taskCommunicatorDescriptors == null || taskCommunicatorDescriptors.isEmpty()) { - if (isPureLocalMode) { - taskCommunicatorDescriptors = Lists.newArrayList(new NamedEntityDescriptor( - TezConstants.getTezUberServicePluginName(), null).setUserPayload(defaultUserPayload)); - } else { - taskCommunicatorDescriptors = Lists.newArrayList(new NamedEntityDescriptor( - TezConstants.getTezYarnServicePluginName(), null).setUserPayload(defaultUserPayload)); - } - } + Preconditions.checkArgument( + taskCommunicatorDescriptors != null && !taskCommunicatorDescriptors.isEmpty(), + "TaskCommunicators must be specified"); this.taskCommunicators = new TaskCommunicator[taskCommunicatorDescriptors.size()]; this.taskCommunicatorContexts = new TaskCommunicatorContext[taskCommunicatorDescriptors.size()]; this.taskCommunicatorServiceWrappers = new ServicePluginLifecycleAbstractService[taskCommunicatorDescriptors.size()]; for (int i = 0 ; i < taskCommunicatorDescriptors.size() ; i++) { - UserPayload userPayload; - if (taskCommunicatorDescriptors.get(i).getEntityName() - .equals(TezConstants.getTezYarnServicePluginName()) || - taskCommunicatorDescriptors.get(i).getEntityName() - .equals(TezConstants.getTezUberServicePluginName())) 
{ - userPayload = defaultUserPayload; - } else { - userPayload = taskCommunicatorDescriptors.get(i).getUserPayload(); - } + UserPayload userPayload = taskCommunicatorDescriptors.get(i).getUserPayload(); taskCommunicatorContexts[i] = new TaskCommunicatorContextImpl(context, this, userPayload, i); taskCommunicators[i] = createTaskCommunicator(taskCommunicatorDescriptors.get(i), i); taskCommunicatorServiceWrappers[i] = new ServicePluginLifecycleAbstractService(taskCommunicators[i]); @@ -154,36 +138,54 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements } } - private TaskCommunicator createTaskCommunicator(NamedEntityDescriptor taskCommDescriptor, int taskCommIndex) { + @VisibleForTesting + TaskCommunicator createTaskCommunicator(NamedEntityDescriptor taskCommDescriptor, + int taskCommIndex) { if (taskCommDescriptor.getEntityName().equals(TezConstants.getTezYarnServicePluginName())) { - LOG.info("Using Default Task Communicator"); - return createTezTaskCommunicator(taskCommunicatorContexts[taskCommIndex]); - } else if (taskCommDescriptor.getEntityName().equals(TezConstants.getTezUberServicePluginName())) { - LOG.info("Using Default Local Task Communicator"); - return new TezLocalTaskCommunicatorImpl(taskCommunicatorContexts[taskCommIndex]); + return createDefaultTaskCommunicator(taskCommunicatorContexts[taskCommIndex]); + } else if (taskCommDescriptor.getEntityName() + .equals(TezConstants.getTezUberServicePluginName())) { + return createUberTaskCommunicator(taskCommunicatorContexts[taskCommIndex]); } else { - LOG.info("Using TaskCommunicator {}:{} " + taskCommDescriptor.getEntityName(), taskCommDescriptor.getClassName()); - Class<? extends TaskCommunicator> taskCommClazz = (Class<? extends TaskCommunicator>) ReflectionUtils - .getClazz(taskCommDescriptor.getClassName()); - try { - Constructor<? 
extends TaskCommunicator> ctor = taskCommClazz.getConstructor(TaskCommunicatorContext.class); - ctor.setAccessible(true); - return ctor.newInstance(taskCommunicatorContexts[taskCommIndex]); - } catch (NoSuchMethodException e) { - throw new TezUncheckedException(e); - } catch (InvocationTargetException e) { - throw new TezUncheckedException(e); - } catch (InstantiationException e) { - throw new TezUncheckedException(e); - } catch (IllegalAccessException e) { - throw new TezUncheckedException(e); - } + return createCustomTaskCommunicator(taskCommunicatorContexts[taskCommIndex], + taskCommDescriptor); } } @VisibleForTesting - protected TezTaskCommunicatorImpl createTezTaskCommunicator(TaskCommunicatorContext context) { - return new TezTaskCommunicatorImpl(context); + TaskCommunicator createDefaultTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) { + LOG.info("Using Default Task Communicator"); + return new TezTaskCommunicatorImpl(taskCommunicatorContext); + } + + @VisibleForTesting + TaskCommunicator createUberTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) { + LOG.info("Using Default Local Task Communicator"); + return new TezLocalTaskCommunicatorImpl(taskCommunicatorContext); + } + + @VisibleForTesting + TaskCommunicator createCustomTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext, + NamedEntityDescriptor taskCommDescriptor) { + LOG.info("Using TaskCommunicator {}:{} " + taskCommDescriptor.getEntityName(), + taskCommDescriptor.getClassName()); + Class<? extends TaskCommunicator> taskCommClazz = + (Class<? extends TaskCommunicator>) ReflectionUtils + .getClazz(taskCommDescriptor.getClassName()); + try { + Constructor<? 
extends TaskCommunicator> ctor = + taskCommClazz.getConstructor(TaskCommunicatorContext.class); + ctor.setAccessible(true); + return ctor.newInstance(taskCommunicatorContext); + } catch (NoSuchMethodException e) { + throw new TezUncheckedException(e); + } catch (InvocationTargetException e) { + throw new TezUncheckedException(e); + } catch (InstantiationException e) { + throw new TezUncheckedException(e); + } catch (IllegalAccessException e) { + throw new TezUncheckedException(e); + } } public TaskHeartbeatResponse heartbeat(TaskHeartbeatRequest request) http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java index 2e8f218..3cc439f 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java @@ -233,9 +233,12 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex, EventHandl private final boolean isSpeculationEnabled; - private final int taskSchedulerIdentifier; - private final int containerLauncherIdentifier; - private final int taskCommunicatorIdentifier; + @VisibleForTesting + final int taskSchedulerIdentifier; + @VisibleForTesting + final int containerLauncherIdentifier; + @VisibleForTesting + final int taskCommunicatorIdentifier; //fields initialized in init http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java index 
2d56bfe..57b4aee 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java @@ -20,7 +20,7 @@ import java.net.UnknownHostException; import java.util.List; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; +import com.google.common.base.Preconditions; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.service.AbstractService; import org.apache.hadoop.yarn.event.EventHandler; @@ -48,8 +48,10 @@ public class ContainerLauncherRouter extends AbstractService static final Logger LOG = LoggerFactory.getLogger(ContainerLauncherImpl.class); - private final ContainerLauncher containerLaunchers[]; - private final ContainerLauncherContext containerLauncherContexts[]; + @VisibleForTesting + final ContainerLauncher containerLaunchers[]; + @VisibleForTesting + final ContainerLauncherContext containerLauncherContexts[]; protected final ServicePluginLifecycleAbstractService[] containerLauncherServiceWrappers; private final AppContext appContext; @@ -64,7 +66,7 @@ public class ContainerLauncherRouter extends AbstractService } // Accepting conf to setup final parameters, if required. 
- public ContainerLauncherRouter(UserPayload defaultUserPayload, AppContext context, + public ContainerLauncherRouter(AppContext context, TaskAttemptListener taskAttemptListener, String workingDirectory, List<NamedEntityDescriptor> containerLauncherDescriptors, @@ -72,79 +74,91 @@ public class ContainerLauncherRouter extends AbstractService super(ContainerLauncherRouter.class.getName()); this.appContext = context; - if (containerLauncherDescriptors == null || containerLauncherDescriptors.isEmpty()) { - if (isPureLocalMode) { - containerLauncherDescriptors = Lists.newArrayList(new NamedEntityDescriptor( - TezConstants.getTezUberServicePluginName(), null).setUserPayload(defaultUserPayload)); - } else { - containerLauncherDescriptors = Lists.newArrayList(new NamedEntityDescriptor( - TezConstants.getTezYarnServicePluginName(), null).setUserPayload(defaultUserPayload)); - } - } + Preconditions.checkArgument( + containerLauncherDescriptors != null && !containerLauncherDescriptors.isEmpty(), + "ContainerLauncherDescriptors must be specified"); containerLauncherContexts = new ContainerLauncherContext[containerLauncherDescriptors.size()]; containerLaunchers = new ContainerLauncher[containerLauncherDescriptors.size()]; containerLauncherServiceWrappers = new ServicePluginLifecycleAbstractService[containerLauncherDescriptors.size()]; for (int i = 0; i < containerLauncherDescriptors.size(); i++) { - UserPayload userPayload; - if (containerLauncherDescriptors.get(i).getEntityName() - .equals(TezConstants.getTezYarnServicePluginName()) || - containerLauncherDescriptors.get(i).getEntityName() - .equals(TezConstants.getTezUberServicePluginName())) { - userPayload = defaultUserPayload; - } else { - userPayload = containerLauncherDescriptors.get(i).getUserPayload(); - } + UserPayload userPayload = containerLauncherDescriptors.get(i).getUserPayload(); ContainerLauncherContext containerLauncherContext = new ContainerLauncherContextImpl(context, taskAttemptListener, userPayload); 
containerLauncherContexts[i] = containerLauncherContext; containerLaunchers[i] = createContainerLauncher(containerLauncherDescriptors.get(i), context, - containerLauncherContext, taskAttemptListener, workingDirectory, isPureLocalMode); + containerLauncherContext, taskAttemptListener, workingDirectory, i, isPureLocalMode); containerLauncherServiceWrappers[i] = new ServicePluginLifecycleAbstractService(containerLaunchers[i]); } } - private ContainerLauncher createContainerLauncher(NamedEntityDescriptor containerLauncherDescriptor, - AppContext context, - ContainerLauncherContext containerLauncherContext, - TaskAttemptListener taskAttemptListener, - String workingDirectory, - boolean isPureLocalMode) throws + @VisibleForTesting + ContainerLauncher createContainerLauncher( + NamedEntityDescriptor containerLauncherDescriptor, + AppContext context, + ContainerLauncherContext containerLauncherContext, + TaskAttemptListener taskAttemptListener, + String workingDirectory, + int containerLauncherIndex, + boolean isPureLocalMode) throws UnknownHostException { if (containerLauncherDescriptor.getEntityName().equals( TezConstants.getTezYarnServicePluginName())) { - LOG.info("Creating DefaultContainerLauncher"); - return new ContainerLauncherImpl(containerLauncherContext); + return createYarnContainerLauncher(containerLauncherContext); } else if (containerLauncherDescriptor.getEntityName() .equals(TezConstants.getTezUberServicePluginName())) { - LOG.info("Creating LocalContainerLauncher"); - // TODO Post TEZ-2003. LocalContainerLauncher is special cased, since it makes use of - // extensive internals which are only available at runtime. Will likely require - // some kind of runtime binding of parameters in the payload to work correctly. 
- return - new LocalContainerLauncher(containerLauncherContext, context, taskAttemptListener, workingDirectory, isPureLocalMode); + return createUberContainerLauncher(containerLauncherContext, context, taskAttemptListener, + workingDirectory, isPureLocalMode); } else { - LOG.info("Creating container launcher {}:{} ", containerLauncherDescriptor.getEntityName(), containerLauncherDescriptor.getClassName()); - Class<? extends ContainerLauncher> containerLauncherClazz = - (Class<? extends ContainerLauncher>) ReflectionUtils.getClazz( - containerLauncherDescriptor.getClassName()); - try { - Constructor<? extends ContainerLauncher> ctor = containerLauncherClazz - .getConstructor(ContainerLauncherContext.class); - ctor.setAccessible(true); - return ctor.newInstance(containerLauncherContext); - } catch (NoSuchMethodException e) { - throw new TezUncheckedException(e); - } catch (InvocationTargetException e) { - throw new TezUncheckedException(e); - } catch (InstantiationException e) { - throw new TezUncheckedException(e); - } catch (IllegalAccessException e) { - throw new TezUncheckedException(e); - } + return createCustomContainerLauncher(containerLauncherContext, containerLauncherDescriptor); + } + } + + @VisibleForTesting + ContainerLauncher createYarnContainerLauncher(ContainerLauncherContext containerLauncherContext) { + LOG.info("Creating DefaultContainerLauncher"); + return new ContainerLauncherImpl(containerLauncherContext); + } + + @VisibleForTesting + ContainerLauncher createUberContainerLauncher(ContainerLauncherContext containerLauncherContext, + AppContext context, + TaskAttemptListener taskAttemptListener, + String workingDirectory, + boolean isPureLocalMode) throws + UnknownHostException { + LOG.info("Creating LocalContainerLauncher"); + // TODO Post TEZ-2003. LocalContainerLauncher is special cased, since it makes use of + // extensive internals which are only available at runtime. 
Will likely require + // some kind of runtime binding of parameters in the payload to work correctly. + return + new LocalContainerLauncher(containerLauncherContext, context, taskAttemptListener, + workingDirectory, isPureLocalMode); + } + + @VisibleForTesting + ContainerLauncher createCustomContainerLauncher(ContainerLauncherContext containerLauncherContext, + NamedEntityDescriptor containerLauncherDescriptor) { + LOG.info("Creating container launcher {}:{} ", containerLauncherDescriptor.getEntityName(), + containerLauncherDescriptor.getClassName()); + Class<? extends ContainerLauncher> containerLauncherClazz = + (Class<? extends ContainerLauncher>) ReflectionUtils.getClazz( + containerLauncherDescriptor.getClassName()); + try { + Constructor<? extends ContainerLauncher> ctor = containerLauncherClazz + .getConstructor(ContainerLauncherContext.class); + ctor.setAccessible(true); + return ctor.newInstance(containerLauncherContext); + } catch (NoSuchMethodException e) { + throw new TezUncheckedException(e); + } catch (InvocationTargetException e) { + throw new TezUncheckedException(e); + } catch (InstantiationException e) { + throw new TezUncheckedException(e); + } catch (IllegalAccessException e) { + throw new TezUncheckedException(e); } - // TODO TEZ-2118 Handle routing to multiple launchers } @Override http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java index d178d61..4d710fa 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java @@ -22,7 +22,6 @@ import java.lang.reflect.Constructor; import 
java.lang.reflect.InvocationTargetException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; @@ -34,10 +33,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.apache.hadoop.classification.InterfaceAudience; import org.apache.tez.dag.api.NamedEntityDescriptor; import org.apache.tez.dag.api.TezConstants; -import org.apache.tez.dag.api.UserPayload; import org.apache.tez.dag.app.ServicePluginLifecycleAbstractService; import org.apache.tez.serviceplugins.api.TaskScheduler; import org.apache.tez.serviceplugins.api.TaskSchedulerContext; @@ -126,9 +123,8 @@ public class TaskSchedulerEventHandler extends AbstractService implements private final boolean isPureLocalMode; // If running in non local-only mode, the YARN task scheduler will always run to take care of // registration with YARN and heartbeats to YARN. - // Splitting registration and heartbeats is not straigh-forward due to the taskScheduler being + // Splitting registration and heartbeats is not straight-forward due to the taskScheduler being // tied to a ContainerRequestType. 
- private final int yarnTaskSchedulerIndex; // Custom AppIds to avoid container conflicts if there's multiple sources private final long SCHEDULER_APP_ID_BASE = 111101111; private final long SCHEDULER_APP_ID_INCREMENT = 111111111; @@ -153,9 +149,10 @@ public class TaskSchedulerEventHandler extends AbstractService implements public TaskSchedulerEventHandler(AppContext appContext, DAGClientServer clientService, EventHandler eventHandler, ContainerSignatureMatcher containerSignatureMatcher, WebUIService webUI, - List<NamedEntityDescriptor> schedulerDescriptors, UserPayload defaultPayload, - boolean isPureLocalMode) { + List<NamedEntityDescriptor> schedulerDescriptors, boolean isPureLocalMode) { super(TaskSchedulerEventHandler.class.getName()); + Preconditions.checkArgument(schedulerDescriptors != null && !schedulerDescriptors.isEmpty(), + "TaskSchedulerDescriptors must be specified"); this.appContext = appContext; this.eventHandler = eventHandler; this.clientService = clientService; @@ -168,50 +165,8 @@ public class TaskSchedulerEventHandler extends AbstractService implements this.webUI.setHistoryUrl(this.historyUrl); } - // Override everything for pure local mode - if (isPureLocalMode) { - this.taskSchedulerDescriptors = new NamedEntityDescriptor[]{ - new NamedEntityDescriptor(TezConstants.getTezUberServicePluginName(), null) - .setUserPayload(defaultPayload)}; - this.yarnTaskSchedulerIndex = -1; - } else { - if (schedulerDescriptors == null || schedulerDescriptors.isEmpty()) { - this.taskSchedulerDescriptors = new NamedEntityDescriptor[]{ - new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) - .setUserPayload(defaultPayload)}; - this.yarnTaskSchedulerIndex = 0; - } else { - // Ensure the YarnScheduler will be setup and note it's index. This will be responsible for heartbeats and YARN registration. 
- int foundYarnTaskSchedulerIndex = -1; - - List<NamedEntityDescriptor> schedulerDescriptorList = new LinkedList<>(); - for (int i = 0 ; i < schedulerDescriptors.size() ; i++) { - if (schedulerDescriptors.get(i).getEntityName().equals( - TezConstants.getTezYarnServicePluginName())) { - schedulerDescriptorList.add( - new NamedEntityDescriptor(schedulerDescriptors.get(i).getEntityName(), null) - .setUserPayload( - defaultPayload)); - foundYarnTaskSchedulerIndex = i; - } else if (schedulerDescriptors.get(i).getEntityName().equals( - TezConstants.getTezUberServicePluginName())) { - schedulerDescriptorList.add( - new NamedEntityDescriptor(schedulerDescriptors.get(i).getEntityName(), null) - .setUserPayload( - defaultPayload)); - } else { - schedulerDescriptorList.add(schedulerDescriptors.get(i)); - } - } - if (foundYarnTaskSchedulerIndex == -1) { - schedulerDescriptorList.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null).setUserPayload( - defaultPayload)); - foundYarnTaskSchedulerIndex = schedulerDescriptorList.size() -1; - } - this.taskSchedulerDescriptors = schedulerDescriptorList.toArray(new NamedEntityDescriptor[schedulerDescriptorList.size()]); - this.yarnTaskSchedulerIndex = foundYarnTaskSchedulerIndex; - } - } + this.taskSchedulerDescriptors = schedulerDescriptors.toArray(new NamedEntityDescriptor[schedulerDescriptors.size()]); + taskSchedulers = new TaskScheduler[this.taskSchedulerDescriptors.length]; taskSchedulerServiceWrappers = new ServicePluginLifecycleAbstractService[this.taskSchedulerDescriptors.length]; } @@ -239,7 +194,8 @@ public class TaskSchedulerEventHandler extends AbstractService implements private ExecutorService createAppCallbackExecutorService() { return Executors.newSingleThreadExecutor( - new ThreadFactoryBuilder().setNameFormat("TaskSchedulerAppCallbackExecutor #%d").setDaemon(true) + new ThreadFactoryBuilder().setNameFormat("TaskSchedulerAppCallbackExecutor #%d") + .setDaemon(true) .build()); } @@ -428,7 +384,8 
@@ public class TaskSchedulerEventHandler extends AbstractService implements event); } - private TaskScheduler createTaskScheduler(String host, int port, String trackingUrl, + @VisibleForTesting + TaskScheduler createTaskScheduler(String host, int port, String trackingUrl, AppContext appContext, NamedEntityDescriptor taskSchedulerDescriptor, long customAppIdIdentifier, @@ -436,32 +393,57 @@ public class TaskSchedulerEventHandler extends AbstractService implements TaskSchedulerContext rawContext = new TaskSchedulerContextImpl(this, appContext, schedulerId, trackingUrl, customAppIdIdentifier, host, port, taskSchedulerDescriptor.getUserPayload()); - TaskSchedulerContext wrappedContext = new TaskSchedulerContextImplWrapper(rawContext, appCallbackExecutor); + TaskSchedulerContext wrappedContext = wrapTaskSchedulerContext(rawContext); String schedulerName = taskSchedulerDescriptor.getEntityName(); if (schedulerName.equals(TezConstants.getTezYarnServicePluginName())) { - LOG.info("Creating TaskScheduler: YarnTaskSchedulerService"); - return new YarnTaskSchedulerService(wrappedContext); + return createYarnTaskScheduler(wrappedContext, schedulerId); } else if (schedulerName.equals(TezConstants.getTezUberServicePluginName())) { - LOG.info("Creating TaskScheduler: Local TaskScheduler"); - return new LocalTaskSchedulerService(wrappedContext); + return createUberTaskScheduler(wrappedContext, schedulerId); } else { - LOG.info("Creating custom TaskScheduler {}:{}", taskSchedulerDescriptor.getEntityName(), taskSchedulerDescriptor.getClassName()); - Class<? extends TaskScheduler> taskSchedulerClazz = - (Class<? extends TaskScheduler>) ReflectionUtils.getClazz(taskSchedulerDescriptor.getClassName()); - try { - Constructor<? 
extends TaskScheduler> ctor = taskSchedulerClazz - .getConstructor(TaskSchedulerContext.class); - ctor.setAccessible(true); - return ctor.newInstance(wrappedContext); - } catch (NoSuchMethodException e) { - throw new TezUncheckedException(e); - } catch (InvocationTargetException e) { - throw new TezUncheckedException(e); - } catch (InstantiationException e) { - throw new TezUncheckedException(e); - } catch (IllegalAccessException e) { - throw new TezUncheckedException(e); - } + return createCustomTaskScheduler(wrappedContext, taskSchedulerDescriptor, schedulerId); + } + } + + @VisibleForTesting + TaskSchedulerContext wrapTaskSchedulerContext(TaskSchedulerContext rawContext) { + return new TaskSchedulerContextImplWrapper(rawContext, appCallbackExecutor); + } + + @VisibleForTesting + TaskScheduler createYarnTaskScheduler(TaskSchedulerContext taskSchedulerContext, + int schedulerId) { + LOG.info("Creating TaskScheduler: YarnTaskSchedulerService"); + return new YarnTaskSchedulerService(taskSchedulerContext); + } + + @VisibleForTesting + TaskScheduler createUberTaskScheduler(TaskSchedulerContext taskSchedulerContext, + int schedulerId) { + LOG.info("Creating TaskScheduler: Local TaskScheduler"); + return new LocalTaskSchedulerService(taskSchedulerContext); + } + + TaskScheduler createCustomTaskScheduler(TaskSchedulerContext taskSchedulerContext, + NamedEntityDescriptor taskSchedulerDescriptor, + int schedulerId) { + LOG.info("Creating custom TaskScheduler {}:{}", taskSchedulerDescriptor.getEntityName(), + taskSchedulerDescriptor.getClassName()); + Class<? extends TaskScheduler> taskSchedulerClazz = + (Class<? extends TaskScheduler>) ReflectionUtils + .getClazz(taskSchedulerDescriptor.getClassName()); + try { + Constructor<? 
extends TaskScheduler> ctor = taskSchedulerClazz + .getConstructor(TaskSchedulerContext.class); + ctor.setAccessible(true); + return ctor.newInstance(taskSchedulerContext); + } catch (NoSuchMethodException e) { + throw new TezUncheckedException(e); + } catch (InvocationTargetException e) { + throw new TezUncheckedException(e); + } catch (InstantiationException e) { + throw new TezUncheckedException(e); + } catch (IllegalAccessException e) { + throw new TezUncheckedException(e); } } @@ -801,9 +783,4 @@ public class TaskSchedulerEventHandler extends AbstractService implements return historyUrl; } - @VisibleForTesting - @InterfaceAudience.Private - ExecutorService getContextExecutorService() { - return appCallbackExecutor; - } } http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java index 0723dbc..2e6e568 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java @@ -511,8 +511,7 @@ public class MockDAGAppMaster extends DAGAppMaster { // use mock container launcher for tests @Override - protected ContainerLauncherRouter createContainerLauncherRouter(final UserPayload defaultUserPayload, - List<NamedEntityDescriptor> containerLauncherDescirptors, + protected ContainerLauncherRouter createContainerLauncherRouter(List<NamedEntityDescriptor> containerLauncherDescirptors, boolean isLocal) throws UnknownHostException { return new ContainerLauncherRouter(containerLauncher, getContext()); http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java ---------------------------------------------------------------------- diff --git 
a/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java new file mode 100644 index 0000000..fa5d87c --- /dev/null +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java @@ -0,0 +1,300 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.dag.app; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedList; +import java.util.List; + +import com.google.common.base.Preconditions; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.Lists; +import com.google.protobuf.ByteString; +import org.apache.hadoop.conf.Configuration; +import org.apache.tez.common.TezUtils; +import org.apache.tez.dag.api.NamedEntityDescriptor; +import org.apache.tez.dag.api.TezConstants; +import org.apache.tez.dag.api.UserPayload; +import org.apache.tez.dag.api.records.DAGProtos; +import org.apache.tez.dag.api.records.DAGProtos.AMPluginDescriptorProto; +import org.apache.tez.dag.api.records.DAGProtos.TezNamedEntityDescriptorProto; +import org.apache.tez.dag.api.records.DAGProtos.TezUserPayloadProto; +import org.junit.Test; + +public class TestDAGAppMaster { + + private static final String TEST_KEY = "TEST_KEY"; + private static final 
String TEST_VAL = "TEST_VAL"; + private static final String TS_NAME = "TS"; + private static final String CL_NAME = "CL"; + private static final String TC_NAME = "TC"; + private static final String CLASS_SUFFIX = "_CLASS"; + + @Test(timeout = 5000) + public void testPluginParsing() throws IOException { + BiMap<String, Integer> pluginMap = HashBiMap.create(); + Configuration conf = new Configuration(false); + conf.set("testkey", "testval"); + UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf); + + List<TezNamedEntityDescriptorProto> entityDescriptors = new LinkedList<>(); + List<NamedEntityDescriptor> entities; + + // Test empty descriptor list, yarn enabled + pluginMap.clear(); + entities = new LinkedList<>(); + DAGAppMaster.parsePlugin(entities, pluginMap, null, true, false, defaultPayload); + assertEquals(1, pluginMap.size()); + assertEquals(1, entities.size()); + assertTrue(pluginMap.containsKey(TezConstants.getTezYarnServicePluginName())); + assertTrue(0 == pluginMap.get(TezConstants.getTezYarnServicePluginName())); + assertEquals("testval", + TezUtils.createConfFromUserPayload(entities.get(0).getUserPayload()).get("testkey")); + + // Test empty descriptor list, uber enabled + pluginMap.clear(); + entities = new LinkedList<>(); + DAGAppMaster.parsePlugin(entities, pluginMap, null, false, true, defaultPayload); + assertEquals(1, pluginMap.size()); + assertEquals(1, entities.size()); + assertTrue(pluginMap.containsKey(TezConstants.getTezUberServicePluginName())); + assertTrue(0 == pluginMap.get(TezConstants.getTezUberServicePluginName())); + assertEquals("testval", + TezUtils.createConfFromUserPayload(entities.get(0).getUserPayload()).get("testkey")); + + // Test empty descriptor list, yarn enabled, uber enabled + pluginMap.clear(); + entities = new LinkedList<>(); + DAGAppMaster.parsePlugin(entities, pluginMap, null, true, true, defaultPayload); + assertEquals(2, pluginMap.size()); + assertEquals(2, entities.size()); + 
assertTrue(pluginMap.containsKey(TezConstants.getTezYarnServicePluginName())); + assertTrue(0 == pluginMap.get(TezConstants.getTezYarnServicePluginName())); + assertTrue(pluginMap.containsKey(TezConstants.getTezUberServicePluginName())); + assertTrue(1 == pluginMap.get(TezConstants.getTezUberServicePluginName())); + + + String pluginName = "d1"; + ByteBuffer bb = ByteBuffer.allocate(4); + bb.putInt(0, 3); + TezNamedEntityDescriptorProto d1 = + TezNamedEntityDescriptorProto.newBuilder().setName(pluginName).setEntityDescriptor( + DAGProtos.TezEntityDescriptorProto.newBuilder().setClassName("d1Class") + .setTezUserPayload( + TezUserPayloadProto.newBuilder() + .setUserPayload(ByteString.copyFrom(bb)))).build(); + entityDescriptors.add(d1); + + // Test descriptor, no yarn, no uber + pluginMap.clear(); + entities = new LinkedList<>(); + DAGAppMaster.parsePlugin(entities, pluginMap, entityDescriptors, false, false, defaultPayload); + assertEquals(1, pluginMap.size()); + assertEquals(1, entities.size()); + assertTrue(pluginMap.containsKey(pluginName)); + assertTrue(0 == pluginMap.get(pluginName)); + + // Test descriptor, yarn and uber + pluginMap.clear(); + entities = new LinkedList<>(); + DAGAppMaster.parsePlugin(entities, pluginMap, entityDescriptors, true, true, defaultPayload); + assertEquals(3, pluginMap.size()); + assertEquals(3, entities.size()); + assertTrue(pluginMap.containsKey(TezConstants.getTezYarnServicePluginName())); + assertTrue(0 == pluginMap.get(TezConstants.getTezYarnServicePluginName())); + assertTrue(pluginMap.containsKey(TezConstants.getTezUberServicePluginName())); + assertTrue(1 == pluginMap.get(TezConstants.getTezUberServicePluginName())); + assertTrue(pluginMap.containsKey(pluginName)); + assertTrue(2 == pluginMap.get(pluginName)); + entityDescriptors.clear(); + } + + + @Test(timeout = 5000) + public void testParseAllPluginsNoneSpecified() throws IOException { + Configuration conf = new Configuration(false); + conf.set(TEST_KEY, TEST_VAL); + 
UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf); + + List<NamedEntityDescriptor> tsDescriptors; + BiMap<String, Integer> tsMap; + List<NamedEntityDescriptor> clDescriptors; + BiMap<String, Integer> clMap; + List<NamedEntityDescriptor> tcDescriptors; + BiMap<String, Integer> tcMap; + + + // No plugins. Non local + tsDescriptors = Lists.newLinkedList(); + tsMap = HashBiMap.create(); + clDescriptors = Lists.newLinkedList(); + clMap = HashBiMap.create(); + tcDescriptors = Lists.newLinkedList(); + tcMap = HashBiMap.create(); + DAGAppMaster.parseAllPlugins(tsDescriptors, tsMap, clDescriptors, clMap, tcDescriptors, tcMap, + null, false, defaultPayload); + verifyDescAndMap(tsDescriptors, tsMap, 1, true, TezConstants.getTezYarnServicePluginName()); + verifyDescAndMap(clDescriptors, clMap, 1, true, TezConstants.getTezYarnServicePluginName()); + verifyDescAndMap(tcDescriptors, tcMap, 1, true, TezConstants.getTezYarnServicePluginName()); + + // No plugins. Local + tsDescriptors = Lists.newLinkedList(); + tsMap = HashBiMap.create(); + clDescriptors = Lists.newLinkedList(); + clMap = HashBiMap.create(); + tcDescriptors = Lists.newLinkedList(); + tcMap = HashBiMap.create(); + DAGAppMaster.parseAllPlugins(tsDescriptors, tsMap, clDescriptors, clMap, tcDescriptors, tcMap, + null, true, defaultPayload); + verifyDescAndMap(tsDescriptors, tsMap, 1, true, TezConstants.getTezUberServicePluginName()); + verifyDescAndMap(clDescriptors, clMap, 1, true, TezConstants.getTezUberServicePluginName()); + verifyDescAndMap(tcDescriptors, tcMap, 1, true, TezConstants.getTezUberServicePluginName()); + } + + @Test(timeout = 5000) + public void testParseAllPluginsOnlyCustomSpecified() throws IOException { + Configuration conf = new Configuration(false); + conf.set(TEST_KEY, TEST_VAL); + UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf); + TezUserPayloadProto payloadProto = TezUserPayloadProto.newBuilder() + 
.setUserPayload(ByteString.copyFrom(defaultPayload.getPayload())).build(); + + AMPluginDescriptorProto proto = createAmPluginDescriptor(false, false, true, payloadProto); + + List<NamedEntityDescriptor> tsDescriptors; + BiMap<String, Integer> tsMap; + List<NamedEntityDescriptor> clDescriptors; + BiMap<String, Integer> clMap; + List<NamedEntityDescriptor> tcDescriptors; + BiMap<String, Integer> tcMap; + + + // Only plugin, Yarn. + tsDescriptors = Lists.newLinkedList(); + tsMap = HashBiMap.create(); + clDescriptors = Lists.newLinkedList(); + clMap = HashBiMap.create(); + tcDescriptors = Lists.newLinkedList(); + tcMap = HashBiMap.create(); + DAGAppMaster.parseAllPlugins(tsDescriptors, tsMap, clDescriptors, clMap, tcDescriptors, tcMap, + proto, false, defaultPayload); + verifyDescAndMap(tsDescriptors, tsMap, 2, true, TS_NAME, + TezConstants.getTezYarnServicePluginName()); + verifyDescAndMap(clDescriptors, clMap, 1, true, CL_NAME); + verifyDescAndMap(tcDescriptors, tcMap, 1, true, TC_NAME); + assertEquals(TS_NAME + CLASS_SUFFIX, tsDescriptors.get(0).getClassName()); + assertEquals(CL_NAME + CLASS_SUFFIX, clDescriptors.get(0).getClassName()); + assertEquals(TC_NAME + CLASS_SUFFIX, tcDescriptors.get(0).getClassName()); + } + + @Test(timeout = 5000) + public void testParseAllPluginsCustomAndYarnSpecified() throws IOException { + Configuration conf = new Configuration(false); + conf.set(TEST_KEY, TEST_VAL); + UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf); + TezUserPayloadProto payloadProto = TezUserPayloadProto.newBuilder() + .setUserPayload(ByteString.copyFrom(defaultPayload.getPayload())).build(); + + AMPluginDescriptorProto proto = createAmPluginDescriptor(true, false, true, payloadProto); + + List<NamedEntityDescriptor> tsDescriptors; + BiMap<String, Integer> tsMap; + List<NamedEntityDescriptor> clDescriptors; + BiMap<String, Integer> clMap; + List<NamedEntityDescriptor> tcDescriptors; + BiMap<String, Integer> tcMap; + + + // Only plugin, Yarn. 
+ tsDescriptors = Lists.newLinkedList(); + tsMap = HashBiMap.create(); + clDescriptors = Lists.newLinkedList(); + clMap = HashBiMap.create(); + tcDescriptors = Lists.newLinkedList(); + tcMap = HashBiMap.create(); + DAGAppMaster.parseAllPlugins(tsDescriptors, tsMap, clDescriptors, clMap, tcDescriptors, tcMap, + proto, false, defaultPayload); + verifyDescAndMap(tsDescriptors, tsMap, 2, true, TezConstants.getTezYarnServicePluginName(), + TS_NAME); + verifyDescAndMap(clDescriptors, clMap, 2, true, TezConstants.getTezYarnServicePluginName(), + CL_NAME); + verifyDescAndMap(tcDescriptors, tcMap, 2, true, TezConstants.getTezYarnServicePluginName(), + TC_NAME); + assertNull(tsDescriptors.get(0).getClassName()); + assertNull(clDescriptors.get(0).getClassName()); + assertNull(tcDescriptors.get(0).getClassName()); + assertEquals(TS_NAME + CLASS_SUFFIX, tsDescriptors.get(1).getClassName()); + assertEquals(CL_NAME + CLASS_SUFFIX, clDescriptors.get(1).getClassName()); + assertEquals(TC_NAME + CLASS_SUFFIX, tcDescriptors.get(1).getClassName()); + } + + private void verifyDescAndMap(List<NamedEntityDescriptor> descriptors, BiMap<String, Integer> map, + int numExpected, boolean verifyPayload, + String... 
expectedNames) throws + IOException { + Preconditions.checkArgument(expectedNames.length == numExpected); + assertEquals(numExpected, descriptors.size()); + assertEquals(numExpected, map.size()); + for (int i = 0; i < numExpected; i++) { + assertEquals(expectedNames[i], descriptors.get(i).getEntityName()); + if (verifyPayload) { + assertEquals(TEST_VAL, + TezUtils.createConfFromUserPayload(descriptors.get(0).getUserPayload()).get(TEST_KEY)); + } + assertTrue(map.get(expectedNames[i]) == i); + assertTrue(map.inverse().get(i) == expectedNames[i]); + } + } + + private AMPluginDescriptorProto createAmPluginDescriptor(boolean enableYarn, boolean enableUber, + boolean addCustom, + TezUserPayloadProto payloadProto) { + AMPluginDescriptorProto.Builder builder = AMPluginDescriptorProto.newBuilder() + .setUberEnabled(enableUber) + .setContainersEnabled(enableYarn); + if (addCustom) { + builder.addTaskSchedulers( + TezNamedEntityDescriptorProto.newBuilder() + .setName(TS_NAME) + .setEntityDescriptor( + DAGProtos.TezEntityDescriptorProto.newBuilder() + .setClassName(TS_NAME + CLASS_SUFFIX) + .setTezUserPayload(payloadProto))) + .addContainerLaunchers( + TezNamedEntityDescriptorProto.newBuilder() + .setName(CL_NAME) + .setEntityDescriptor( + DAGProtos.TezEntityDescriptorProto.newBuilder() + .setClassName(CL_NAME + CLASS_SUFFIX) + .setTezUserPayload(payloadProto))) + .addTaskCommunicators( + TezNamedEntityDescriptorProto.newBuilder() + .setName(TC_NAME) + .setEntityDescriptor( + DAGProtos.TezEntityDescriptorProto.newBuilder() + .setClassName(TC_NAME + CLASS_SUFFIX) + .setTezUserPayload(payloadProto))); + } + return builder.build(); + } + + +} http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java 
b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java index 1cb69a8..639c487 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java @@ -34,6 +34,7 @@ import java.util.List; import java.util.Map; import java.util.Random; +import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.Text; import org.apache.hadoop.security.Credentials; @@ -52,7 +53,9 @@ import org.apache.tez.common.security.JobTokenIdentifier; import org.apache.tez.common.security.JobTokenSecretManager; import org.apache.tez.common.security.TokenCache; import org.apache.tez.dag.api.NamedEntityDescriptor; +import org.apache.tez.dag.api.TaskCommunicator; import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.api.TezConstants; import org.apache.tez.dag.api.TezUncheckedException; import org.apache.tez.dag.api.UserPayload; import org.apache.tez.serviceplugins.api.ContainerEndReason; @@ -146,7 +149,10 @@ public class TestTaskAttemptListenerImplTezDag { throw new TezUncheckedException(e); } taskAttemptListener = new TaskAttemptListenerImplForTest(appContext, - mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), null, defaultPayload, false); + mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), + Lists.newArrayList( + new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) + .setUserPayload(defaultPayload))); TezTaskCommunicatorImpl taskCommunicator = (TezTaskCommunicatorImpl)taskAttemptListener.getTaskCommunicator(); TezTaskUmbilicalProtocol tezUmbilical = taskCommunicator.getUmbilical(); @@ -301,7 +307,7 @@ public class TestTaskAttemptListenerImplTezDag { // TODO TEZ-2003 Move this into TestTezTaskCommunicator. Potentially other tests as well. 
@Test (timeout= 5000) - public void testPortRange_NotSpecified() { + public void testPortRange_NotSpecified() throws IOException { Configuration conf = new Configuration(); JobTokenIdentifier identifier = new JobTokenIdentifier(new Text( "fakeIdentifier")); @@ -309,14 +315,11 @@ public class TestTaskAttemptListenerImplTezDag { new JobTokenSecretManager()); sessionToken.setService(identifier.getJobId()); TokenCache.setSessionToken(sessionToken, credentials); - UserPayload userPayload = null; - try { - userPayload = TezUtils.createUserPayloadFromConf(conf); - } catch (IOException e) { - throw new TezUncheckedException(e); - } + UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf); taskAttemptListener = new TaskAttemptListenerImpTezDag(appContext, - mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), null, userPayload, false); + mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), Lists.newArrayList( + new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) + .setUserPayload(userPayload))); // no exception happen, should started properly taskAttemptListener.init(conf); taskAttemptListener.start(); @@ -335,14 +338,12 @@ public class TestTaskAttemptListenerImplTezDag { TokenCache.setSessionToken(sessionToken, credentials); conf.set(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE, port + "-" + port); - UserPayload userPayload = null; - try { - userPayload = TezUtils.createUserPayloadFromConf(conf); - } catch (IOException e) { - throw new TezUncheckedException(e); - } + UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf); + taskAttemptListener = new TaskAttemptListenerImpTezDag(appContext, - mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), null, userPayload, false); + mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), Lists + .newArrayList(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) + 
.setUserPayload(userPayload))); taskAttemptListener.init(conf); taskAttemptListener.start(); int resultedPort = taskAttemptListener.getTaskCommunicator(0).getAddress().getPort(); @@ -398,16 +399,13 @@ public class TestTaskAttemptListenerImplTezDag { public TaskAttemptListenerImplForTest(AppContext context, TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh, - List<NamedEntityDescriptor> taskCommDescriptors, - UserPayload userPayload, - boolean isPureLocalMode) { - super(context, thh, chh, taskCommDescriptors, userPayload, - isPureLocalMode); + List<NamedEntityDescriptor> taskCommDescriptors) { + super(context, thh, chh, taskCommDescriptors); } @Override - protected TezTaskCommunicatorImpl createTezTaskCommunicator(TaskCommunicatorContext context) { - return new TezTaskCommunicatorImplForTest(context); + TaskCommunicator createDefaultTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) { + return new TezTaskCommunicatorImplForTest(taskCommunicatorContext); } } http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java index 1c82bd8..abb5e42 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.security.Credentials; import org.apache.hadoop.yarn.api.records.ApplicationAccessType; @@ -37,7 +38,9 @@ import org.apache.hadoop.yarn.api.records.NodeId; import 
org.apache.hadoop.yarn.event.Event; import org.apache.hadoop.yarn.event.EventHandler; import org.apache.tez.common.TezUtils; +import org.apache.tez.dag.api.NamedEntityDescriptor; import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.api.TezConstants; import org.apache.tez.dag.api.TezUncheckedException; import org.apache.tez.dag.api.UserPayload; import org.apache.tez.serviceplugins.api.TaskAttemptEndReason; @@ -83,7 +86,8 @@ public class TestTaskAttemptListenerImplTezDag2 { UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf); TaskAttemptListenerImpTezDag taskAttemptListener = new TaskAttemptListenerImpTezDag(appContext, mock(TaskHeartbeatHandler.class), - mock(ContainerHeartbeatHandler.class), null, userPayload, false); + mock(ContainerHeartbeatHandler.class), Lists.newArrayList(new NamedEntityDescriptor( + TezConstants.getTezYarnServicePluginName(), null).setUserPayload(userPayload))); TaskSpec taskSpec1 = mock(TaskSpec.class); TezTaskAttemptID taskAttemptId1 = mock(TezTaskAttemptID.class); http://git-wip-us.apache.org/repos/asf/tez/blob/7ef9dda7/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskCommunicatorManager.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskCommunicatorManager.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskCommunicatorManager.java new file mode 100644 index 0000000..c76aa50 --- /dev/null +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskCommunicatorManager.java @@ -0,0 +1,369 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.dag.app; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.NodeId; +import org.apache.tez.common.TezUtils; +import org.apache.tez.dag.api.NamedEntityDescriptor; +import org.apache.tez.dag.api.TaskCommunicator; +import org.apache.tez.dag.api.TaskCommunicatorContext; +import org.apache.tez.dag.api.TezConstants; +import org.apache.tez.dag.api.UserPayload; +import org.apache.tez.dag.api.event.VertexStateUpdate; +import org.apache.tez.dag.records.TezTaskAttemptID; +import org.apache.tez.runtime.api.impl.TaskSpec; +import 
org.apache.tez.serviceplugins.api.ContainerEndReason;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests the TaskCommunicator handling in {@link TaskAttemptListenerImpTezDag}.
 *
 * <p>Covers creation of custom and YARN task communicators from {@link NamedEntityDescriptor}s,
 * propagation of their user payloads, and routing of container registrations to the
 * communicator selected by index.
 */
public class TestTaskCommunicatorManager {

  private static final String CUSTOM_TASK_COMM_NAME = "customTaskComm";

  /** Clears the shared static bookkeeping of the test manager before and after every test. */
  @Before
  @After
  public void reset() {
    TaskCommManagerForMultipleCommTest.reset();
  }

  @Test(timeout = 5000)
  public void testNoTaskCommSpecified() throws IOException {
    AppContext appContext = mock(AppContext.class);
    TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
    ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);

    try {
      new TaskCommManagerForMultipleCommTest(appContext, thh, chh, null);
      fail("Initialization should have failed without a TaskComm specified");
    } catch (IllegalArgumentException expected) {
      // Expected: constructing the manager with no TaskCommunicator descriptors must be rejected.
    }
  }

  @Test(timeout = 5000)
  public void testCustomTaskCommSpecified() throws IOException {
    AppContext appContext = mock(AppContext.class);
    TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
    ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);

    ByteBuffer bb = createCustomPayloadBuffer();
    List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<>();
    taskCommDescriptors.add(createCustomTaskCommDescriptor(bb));

    TaskCommManagerForMultipleCommTest tcm =
        new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);

    try {
      tcm.init(new Configuration(false));
      tcm.start();

      assertEquals(1, tcm.getNumTaskComms());
      assertFalse(tcm.getYarnTaskCommCreated());
      assertFalse(tcm.getUberTaskCommCreated());

      assertEquals(CUSTOM_TASK_COMM_NAME, tcm.getTaskCommName(0));
      assertEquals(bb, tcm.getTaskCommContext(0).getInitialUserPayload().getPayload());
    } finally {
      tcm.stop();
    }
  }

  @Test(timeout = 5000)
  public void testMultipleTaskComms() throws IOException {
    AppContext appContext = mock(AppContext.class);
    TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
    ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);

    ByteBuffer bb = createCustomPayloadBuffer();
    List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<>();
    taskCommDescriptors.add(createCustomTaskCommDescriptor(bb));
    taskCommDescriptors.add(createYarnTaskCommDescriptor());

    TaskCommManagerForMultipleCommTest tcm =
        new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);

    try {
      tcm.init(new Configuration(false));
      tcm.start();

      assertEquals(2, tcm.getNumTaskComms());
      assertTrue(tcm.getYarnTaskCommCreated());
      assertFalse(tcm.getUberTaskCommCreated());

      assertEquals(CUSTOM_TASK_COMM_NAME, tcm.getTaskCommName(0));
      assertEquals(bb, tcm.getTaskCommContext(0).getInitialUserPayload().getPayload());

      assertEquals(TezConstants.getTezYarnServicePluginName(), tcm.getTaskCommName(1));
      Configuration confParsed = TezUtils
          .createConfFromUserPayload(tcm.getTaskCommContext(1).getInitialUserPayload());
      assertEquals("testvalue", confParsed.get("testkey"));
    } finally {
      tcm.stop();
    }
  }

  @Test(timeout = 5000)
  public void testEventRouting() throws Exception {
    AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
    NodeId nodeId = NodeId.newInstance("host1", 3131);
    when(appContext.getAllContainers().get(any(ContainerId.class)).getContainer().getNodeId())
        .thenReturn(nodeId);
    TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
    ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);

    List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<>();
    taskCommDescriptors.add(createCustomTaskCommDescriptor(createCustomPayloadBuffer()));
    taskCommDescriptors.add(createYarnTaskCommDescriptor());

    TaskCommManagerForMultipleCommTest tcm =
        new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);

    try {
      tcm.init(new Configuration(false));
      tcm.start();

      assertEquals(2, tcm.getNumTaskComms());
      assertTrue(tcm.getYarnTaskCommCreated());
      assertFalse(tcm.getUberTaskCommCreated());

      verify(tcm.getTestTaskComm(0)).initialize();
      verify(tcm.getTestTaskComm(0)).start();
      verify(tcm.getTestTaskComm(1)).initialize();
      verify(tcm.getTestTaskComm(1)).start();

      // A container registered against index N must reach only TaskCommunicator N.
      ContainerId containerId1 = mock(ContainerId.class);
      tcm.registerRunningContainer(containerId1, 0);
      verify(tcm.getTestTaskComm(0)).registerRunningContainer(eq(containerId1), eq("host1"),
          eq(3131));

      ContainerId containerId2 = mock(ContainerId.class);
      tcm.registerRunningContainer(containerId2, 1);
      verify(tcm.getTestTaskComm(1)).registerRunningContainer(eq(containerId2), eq("host1"),
          eq(3131));
    } finally {
      tcm.stop();
    }
    // Verified after (not inside) the finally block so that an assertion failure in the try
    // body is reported as-is instead of being replaced by a secondary verification failure.
    verify(tcm.getTaskCommunicator(0)).shutdown();
    verify(tcm.getTaskCommunicator(1)).shutdown();
  }

  /** Returns a 4-byte buffer holding the int 3 at offset 0, used as the custom comm payload. */
  private static ByteBuffer createCustomPayloadBuffer() {
    ByteBuffer bb = ByteBuffer.allocate(4);
    bb.putInt(0, 3);
    return bb;
  }

  /**
   * Builds the descriptor for the custom {@link FakeTaskComm}, carrying {@code payloadBuffer}
   * as its user payload.
   */
  private static NamedEntityDescriptor createCustomTaskCommDescriptor(ByteBuffer payloadBuffer) {
    return new NamedEntityDescriptor(CUSTOM_TASK_COMM_NAME, FakeTaskComm.class.getName())
        .setUserPayload(UserPayload.create(payloadBuffer));
  }

  /**
   * Builds the descriptor for the built-in YARN TaskCommunicator (no class name), whose payload
   * is a serialized Configuration containing {@code testkey=testvalue}.
   *
   * @throws IOException if the configuration cannot be serialized into a payload
   */
  private static NamedEntityDescriptor createYarnTaskCommDescriptor() throws IOException {
    Configuration conf = new Configuration(false);
    conf.set("testkey", "testvalue");
    UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
    return new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
        .setUserPayload(defaultPayload);
  }

  /**
   * {@link TaskAttemptListenerImpTezDag} subclass that records which TaskCommunicators get
   * created, in which order, and with which contexts/names.
   *
   * <p>It substitutes mocks for the YARN and uber communicators and wraps custom ones in
   * Mockito spies.
   */
  static class TaskCommManagerForMultipleCommTest extends TaskAttemptListenerImpTezDag {

    // All state is static because the overridden factory methods below are invoked from the
    // TaskAttemptListenerImpTezDag constructor (via super(...)), before the instance field
    // initializers of this subclass have run.
    private static final AtomicInteger numTaskComms = new AtomicInteger(0);
    private static final Set<Integer> taskCommIndices = new HashSet<>();
    // Not final: recreated in reset() so that invocations recorded in one test do not leak into
    // verify(...) calls (default times(1)) of another test, which would make results depend on
    // test execution order.
    private static TaskCommunicator yarnTaskComm = mock(TaskCommunicator.class);
    private static TaskCommunicator uberTaskComm = mock(TaskCommunicator.class);
    private static final AtomicBoolean yarnTaskCommCreated = new AtomicBoolean(false);
    private static final AtomicBoolean uberTaskCommCreated = new AtomicBoolean(false);

    // Index i in each list corresponds to the i-th TaskCommunicator created.
    private static final List<TaskCommunicatorContext> taskCommContexts =
        new LinkedList<>();
    private static final List<String> taskCommNames = new LinkedList<>();
    private static final List<TaskCommunicator> testTaskComms = new LinkedList<>();

    /** Clears all recorded state and replaces the YARN/uber mocks with fresh ones. */
    public static void reset() {
      numTaskComms.set(0);
      taskCommIndices.clear();
      yarnTaskComm = mock(TaskCommunicator.class);
      uberTaskComm = mock(TaskCommunicator.class);
      yarnTaskCommCreated.set(false);
      uberTaskCommCreated.set(false);
      taskCommContexts.clear();
      taskCommNames.clear();
      testTaskComms.clear();
    }

    public TaskCommManagerForMultipleCommTest(AppContext context,
                                              TaskHeartbeatHandler thh,
                                              ContainerHeartbeatHandler chh,
                                              List<NamedEntityDescriptor> taskCommunicatorDescriptors) {
      super(context, thh, chh, taskCommunicatorDescriptors);
    }

    @Override
    TaskCommunicator createTaskCommunicator(NamedEntityDescriptor taskCommDescriptor,
                                            int taskCommIndex) {
      numTaskComms.incrementAndGet();
      boolean added = taskCommIndices.add(taskCommIndex);
      assertTrue("Cannot add multiple taskComms with the same index", added);
      taskCommNames.add(taskCommDescriptor.getEntityName());
      return super.createTaskCommunicator(taskCommDescriptor, taskCommIndex);
    }

    @Override
    TaskCommunicator createDefaultTaskCommunicator(
        TaskCommunicatorContext taskCommunicatorContext) {
      taskCommContexts.add(taskCommunicatorContext);
      yarnTaskCommCreated.set(true);
      testTaskComms.add(yarnTaskComm);
      return yarnTaskComm;
    }

    @Override
    TaskCommunicator createUberTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) {
      taskCommContexts.add(taskCommunicatorContext);
      uberTaskCommCreated.set(true);
      testTaskComms.add(uberTaskComm);
      return uberTaskComm;
    }

    @Override
    TaskCommunicator createCustomTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext,
                                                  NamedEntityDescriptor taskCommDescriptor) {
      taskCommContexts.add(taskCommunicatorContext);
      // Spy on the real instance so tests can verify lifecycle and routing calls on it.
      TaskCommunicator spyComm =
          spy(super.createCustomTaskCommunicator(taskCommunicatorContext, taskCommDescriptor));
      testTaskComms.add(spyComm);
      return spyComm;
    }

    public static int getNumTaskComms() {
      return numTaskComms.get();
    }

    public static boolean getYarnTaskCommCreated() {
      return yarnTaskCommCreated.get();
    }

    public static boolean getUberTaskCommCreated() {
      return uberTaskCommCreated.get();
    }

    public static TaskCommunicatorContext getTaskCommContext(int taskCommIndex) {
      return taskCommContexts.get(taskCommIndex);
    }

    public static String getTaskCommName(int taskCommIndex) {
      return taskCommNames.get(taskCommIndex);
    }

    public static TaskCommunicator getTestTaskComm(int taskCommIndex) {
      return testTaskComms.get(taskCommIndex);
    }
  }

  /**
   * No-op TaskCommunicator that the custom descriptors reference by class name.
   *
   * <p>Every callback is empty, and {@link #getAddress()} / {@link #getMetaInfo()} return
   * {@code null}.
   */
  public static class FakeTaskComm extends TaskCommunicator {

    public FakeTaskComm(TaskCommunicatorContext taskCommunicatorContext) {
      super(taskCommunicatorContext);
    }

    @Override
    public void registerRunningContainer(ContainerId containerId, String hostname, int port) {

    }

    @Override
    public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason) {

    }

    @Override
    public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
                                           Map<String, LocalResource> additionalResources,
                                           Credentials credentials, boolean credentialsChanged,
                                           int priority) {

    }

    @Override
    public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID,
                                             TaskAttemptEndReason endReason) {

    }

    @Override
    public InetSocketAddress getAddress() {
      return null;
    }

    @Override
    public void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws Exception {

    }

    @Override
    public void dagComplete(String dagName) {

    }

    @Override
    public Object getMetaInfo() {
      return null;
    }
  }
}
