TEZ-2006. Task communication plane needs to be pluggable. (sseth)
Project: http://git-wip-us.apache.org/repos/asf/tez/repo Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/7ab75d82 Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/7ab75d82 Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/7ab75d82 Branch: refs/heads/TEZ-2003 Commit: 7ab75d8263b9e6a894b8574656352dc8ef2d56f0 Parents: e091591 Author: Siddharth Seth <[email protected]> Authored: Thu Feb 12 11:25:45 2015 -0800 Committer: Siddharth Seth <[email protected]> Committed: Thu Aug 6 01:22:32 2015 -0700 ---------------------------------------------------------------------- TEZ-2003-CHANGES.txt | 1 + .../apache/tez/dag/api/TaskCommunicator.java | 54 ++ .../tez/dag/api/TaskCommunicatorContext.java | 48 ++ .../tez/dag/api/TaskHeartbeatRequest.java | 63 +++ .../tez/dag/api/TaskHeartbeatResponse.java | 39 ++ .../java/org/apache/tez/dag/app/AppContext.java | 2 + .../org/apache/tez/dag/app/DAGAppMaster.java | 5 + .../dag/app/TaskAttemptListenerImpTezDag.java | 532 +++++++------------ .../tez/dag/app/TezTaskCommunicatorImpl.java | 476 +++++++++++++++++ .../app/launcher/LocalContainerLauncher.java | 10 +- .../tez/dag/app/rm/container/AMContainer.java | 3 +- .../rm/container/AMContainerEventAssignTA.java | 2 + .../dag/app/rm/container/AMContainerImpl.java | 1 + .../apache/tez/dag/app/MockDAGAppMaster.java | 27 +- .../app/TestTaskAttemptListenerImplTezDag.java | 81 +-- 15 files changed, 968 insertions(+), 376 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/TEZ-2003-CHANGES.txt ---------------------------------------------------------------------- diff --git a/TEZ-2003-CHANGES.txt b/TEZ-2003-CHANGES.txt index 1822fcb..d7e4be5 100644 --- a/TEZ-2003-CHANGES.txt +++ b/TEZ-2003-CHANGES.txt @@ -1,4 +1,5 @@ ALL CHANGES: TEZ-2019. Temporarily allow the scheduler and launcher to be specified via configuration. + TEZ-2006. Task communication plane needs to be pluggable. 
INCOMPATIBLE CHANGES: http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicator.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicator.java b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicator.java new file mode 100644 index 0000000..97f9c16 --- /dev/null +++ b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicator.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.dag.api; + +import java.net.InetSocketAddress; +import java.util.Map; + +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.service.AbstractService; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.tez.dag.records.TezTaskAttemptID; +import org.apache.tez.runtime.api.impl.TaskSpec; + +// TODO TEZ-2003 Move this into the tez-api module +public abstract class TaskCommunicator extends AbstractService { + public TaskCommunicator(String name) { + super(name); + } + + // TODO TEZ-2003 Ideally, don't expose YARN containerId; instead expose a Tez specific construct. 
+ // TODO When talking to an external service, this plugin implementer may need access to a host:port + public abstract void registerRunningContainer(ContainerId containerId, String hostname, int port); + + // TODO TEZ-2003 Ideally, don't expose YARN containerId; instead expose a Tez specific construct. + public abstract void registerContainerEnd(ContainerId containerId); + + // TODO TEZ-2003 TaskSpec breakup into a clean interface + // TODO TEZ-2003 Add support for priority + public abstract void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec, + Map<String, LocalResource> additionalResources, + Credentials credentials, + boolean credentialsChanged); + + // TODO TEZ-2003 Remove reference to TaskAttemptID + public abstract void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID); + + // TODO TEZ-2003 This doesn't necessarily belong here. A server may not start within the AM. + public abstract InetSocketAddress getAddress(); + + // TODO Eventually. Add methods here to support preemption of tasks. +} http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicatorContext.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicatorContext.java b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicatorContext.java new file mode 100644 index 0000000..9b2d889 --- /dev/null +++ b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicatorContext.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.dag.api; + +import java.io.IOException; + +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.tez.dag.records.TezTaskAttemptID; + + +// Do not make calls into this from within a held lock. + +// TODO TEZ-2003 Move this into the tez-api module +public interface TaskCommunicatorContext { + + // TODO TEZ-2003 Add signalling back into this to indicate errors - e.g. Container unregistered, task no longer running, etc. + + // TODO TEZ-2003 Maybe add book-keeping as a helper library, instead of each impl tracking container to task etc. + + ApplicationAttemptId getApplicationAttemptId(); + Credentials getCredentials(); + + // TODO TEZ-2003 Move to vertex, taskIndex, version + boolean canCommit(TezTaskAttemptID taskAttemptId) throws IOException; + + TaskHeartbeatResponse heartbeat(TaskHeartbeatRequest request) throws IOException, TezException; + + boolean isKnownContainer(ContainerId containerId); + + // TODO TEZ-2003 Move to vertex, taskIndex, version + void taskStartedRemotely(TezTaskAttemptID taskAttemptID, ContainerId containerId); + + // TODO Eventually Add methods to report availability stats to the scheduler.
+} http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatRequest.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatRequest.java b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatRequest.java new file mode 100644 index 0000000..f6bc8f0 --- /dev/null +++ b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatRequest.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.dag.api; + +import java.util.List; + +import org.apache.tez.dag.records.TezTaskAttemptID; +import org.apache.tez.runtime.api.impl.TezEvent; + +// TODO TEZ-2003 Move this into the tez-api module +public class TaskHeartbeatRequest { + + // TODO TEZ-2003 Ideally containerIdentifier should not be part of the request. 
+ // Replace with a task lookup - vertex name + task index + private final String containerIdentifier; + // TODO TEZ-2003 Get rid of the task attemptId reference if possible + private final TezTaskAttemptID taskAttemptId; + private final List<TezEvent> events; + private final int startIndex; + private final int maxEvents; + + + public TaskHeartbeatRequest(String containerIdentifier, TezTaskAttemptID taskAttemptId, List<TezEvent> events, int startIndex, + int maxEvents) { + this.containerIdentifier = containerIdentifier; + this.taskAttemptId = taskAttemptId; + this.events = events; + this.startIndex = startIndex; + this.maxEvents = maxEvents; + } + + public String getContainerIdentifier() { + return containerIdentifier; + } + + public TezTaskAttemptID getTaskAttemptId() { + return taskAttemptId; + } + + public List<TezEvent> getEvents() { + return events; + } + + public int getStartIndex() { + return startIndex; + } + + public int getMaxEvents() { + return maxEvents; + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatResponse.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatResponse.java b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatResponse.java new file mode 100644 index 0000000..c82a743 --- /dev/null +++ b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatResponse.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.dag.api; + +import java.util.List; + +import org.apache.tez.runtime.api.impl.TezEvent; + +// TODO TEZ-2003 Move this into the tez-api module +public class TaskHeartbeatResponse { + + private final boolean shouldDie; + private List<TezEvent> events; + + public TaskHeartbeatResponse(boolean shouldDie, List<TezEvent> events) { + this.shouldDie = shouldDie; + this.events = events; + } + + public boolean isShouldDie() { + return shouldDie; + } + + public List<TezEvent> getEvents() { + return events; + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java b/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java index e909d80..bf3e318 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java @@ -24,6 +24,7 @@ import java.util.Set; import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.security.Credentials; import org.apache.hadoop.yarn.api.records.ApplicationAccessType; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.hadoop.yarn.api.records.ApplicationId; @@ -112,4 +113,5 @@ public interface AppContext { /** Whether the AM is in the process of shutting down/completing */ boolean isAMInCompletionState(); + Credentials getAppCredentials(); } http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java ---------------------------------------------------------------------- diff --git 
a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java index a45731f..dbcbdd0 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java @@ -1526,6 +1526,11 @@ public class DAGAppMaster extends AbstractService { } @Override + public Credentials getAppCredentials() { + return amCredentials; + } + + @Override public Map<ApplicationAccessType, String> getApplicationACLs() { if (getServiceState() != STATE.STARTED) { throw new TezUncheckedException( http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java index 15cb801..ff50907 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java @@ -18,15 +18,14 @@ package org.apache.tez.dag.app; import java.io.IOException; -import java.net.InetAddress; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; import java.net.InetSocketAddress; import java.net.URISyntaxException; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -37,222 +36,212 @@ import org.apache.tez.runtime.api.events.TaskStatusUpdateEvent; import org.apache.tez.runtime.api.impl.EventType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.hadoop.classification.InterfaceAudience; import 
org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.CommonConfigurationKeysPublic; -import org.apache.hadoop.ipc.ProtocolSignature; -import org.apache.hadoop.ipc.RPC; -import org.apache.hadoop.ipc.Server; -import org.apache.hadoop.net.NetUtils; -import org.apache.hadoop.security.authorize.PolicyProvider; +import org.apache.hadoop.security.Credentials; import org.apache.hadoop.service.AbstractService; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; +import org.apache.hadoop.yarn.api.records.NodeId; +import org.apache.tez.common.ReflectionUtils; +import org.apache.tez.dag.api.TaskCommunicator; +import org.apache.tez.dag.api.TaskCommunicatorContext; +import org.apache.tez.dag.api.TaskHeartbeatResponse; +import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.dag.api.TezException; import org.apache.tez.dag.api.TezUncheckedException; import org.apache.hadoop.yarn.api.records.ContainerId; -import org.apache.hadoop.yarn.api.records.LocalResource; import org.apache.hadoop.yarn.util.ConverterUtils; -import org.apache.tez.common.ContainerContext; -import org.apache.tez.common.ContainerTask; -import org.apache.tez.common.TezConverterUtils; -import org.apache.tez.common.TezLocalResource; -import org.apache.tez.common.TezTaskUmbilicalProtocol; -import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.api.TaskHeartbeatRequest; import org.apache.tez.dag.app.dag.DAG; import org.apache.tez.dag.app.dag.Task; import org.apache.tez.dag.app.dag.event.TaskAttemptEventStartedRemotely; import org.apache.tez.dag.app.dag.event.VertexEventRouteEvent; +import org.apache.tez.dag.app.rm.TaskSchedulerService; import org.apache.tez.dag.app.rm.container.AMContainerTask; -import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider; import org.apache.tez.dag.records.TezTaskAttemptID; import org.apache.tez.dag.records.TezVertexID; import org.apache.tez.runtime.api.impl.TezEvent; -import 
org.apache.tez.runtime.api.impl.TezHeartbeatRequest; -import org.apache.tez.runtime.api.impl.TezHeartbeatResponse; import org.apache.tez.common.security.JobTokenSecretManager; -import com.google.common.collect.Maps; @SuppressWarnings("unchecked") +@InterfaceAudience.Private public class TaskAttemptListenerImpTezDag extends AbstractService implements - TezTaskUmbilicalProtocol, TaskAttemptListener { - - private static final ContainerTask TASK_FOR_INVALID_JVM = new ContainerTask( - null, true, null, null, false); + TaskAttemptListener, TaskCommunicatorContext { private static final Logger LOG = LoggerFactory .getLogger(TaskAttemptListenerImpTezDag.class); private final AppContext context; + private TaskCommunicator taskCommunicator; protected final TaskHeartbeatHandler taskHeartbeatHandler; protected final ContainerHeartbeatHandler containerHeartbeatHandler; - private final JobTokenSecretManager jobTokenSecretManager; - private InetSocketAddress address; - private Server server; - - static class ContainerInfo { - ContainerInfo() { - this.lastReponse = null; - this.lastRequestId = 0; - this.amContainerTask = null; - this.taskPulled = false; + + private final TaskHeartbeatResponse RESPONSE_SHOULD_DIE = new TaskHeartbeatResponse(true, null); + + private final ConcurrentMap<TezTaskAttemptID, ContainerId> registeredAttempts = + new ConcurrentHashMap<TezTaskAttemptID, ContainerId>(); + private final ConcurrentMap<ContainerId, ContainerInfo> registeredContainers = + new ConcurrentHashMap<ContainerId, ContainerInfo>(); + + // Defined primarily to work around ConcurrentMaps not accepting null values + private static final class ContainerInfo { + TezTaskAttemptID taskAttemptId; + ContainerInfo(TezTaskAttemptID taskAttemptId) { + this.taskAttemptId = taskAttemptId; } - long lastRequestId; - TezHeartbeatResponse lastReponse; - AMContainerTask amContainerTask; - boolean taskPulled; } - private ConcurrentMap<TezTaskAttemptID, ContainerId> attemptToInfoMap = - new
ConcurrentHashMap<TezTaskAttemptID, ContainerId>(); + private static final ContainerInfo NULL_CONTAINER_INFO = new ContainerInfo(null); - private ConcurrentHashMap<ContainerId, ContainerInfo> registeredContainers = - new ConcurrentHashMap<ContainerId, ContainerInfo>(); public TaskAttemptListenerImpTezDag(AppContext context, - TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh, - JobTokenSecretManager jobTokenSecretManager) { + TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh, + // TODO TEZ-2003 pre-merge. Remove reference to JobTokenSecretManager. + JobTokenSecretManager jobTokenSecretManager) { super(TaskAttemptListenerImpTezDag.class.getName()); this.context = context; - this.jobTokenSecretManager = jobTokenSecretManager; this.taskHeartbeatHandler = thh; this.containerHeartbeatHandler = chh; + this.taskCommunicator = new TezTaskCommunicatorImpl(this); } @Override - public void serviceStart() { - startRpcServer(); - } - - protected void startRpcServer() { - Configuration conf = getConfig(); - if (!conf.getBoolean(TezConfiguration.TEZ_LOCAL_MODE, TezConfiguration.TEZ_LOCAL_MODE_DEFAULT)) { - try { - server = new RPC.Builder(conf) - .setProtocol(TezTaskUmbilicalProtocol.class) - .setBindAddress("0.0.0.0") - .setPort(0) - .setInstance(this) - .setNumHandlers( - conf.getInt(TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT, - TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT)) - .setPortRangeConfig(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE) - .setSecretManager(jobTokenSecretManager).build(); - - // Enable service authorization? 
- if (conf.getBoolean( - CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION, - false)) { - refreshServiceAcls(conf, new TezAMPolicyProvider()); - } - - server.start(); - InetSocketAddress serverBindAddress = NetUtils.getConnectAddress(server); - this.address = NetUtils.createSocketAddrForHost( - serverBindAddress.getAddress().getCanonicalHostName(), - serverBindAddress.getPort()); - - LOG.info("Instantiated TaskAttemptListener RPC at " + this.address); - } catch (IOException e) { - throw new TezUncheckedException(e); - } + public void serviceInit(Configuration conf) { + String taskCommClassName = conf.get(TezConfiguration.TEZ_AM_TASK_COMMUNICATOR_CLASS); + if (taskCommClassName == null) { + LOG.info("Using Default Task Communicator"); + this.taskCommunicator = new TezTaskCommunicatorImpl(this); } else { + LOG.info("Using TaskCommunicator: " + taskCommClassName); + Class<? extends TaskCommunicator> taskCommClazz = (Class<? extends TaskCommunicator>) ReflectionUtils + .getClazz(taskCommClassName); try { - this.address = new InetSocketAddress(InetAddress.getLocalHost(), 0); - } catch (UnknownHostException e) { + Constructor<? 
extends TaskCommunicator> ctor = taskCommClazz.getConstructor(TaskCommunicatorContext.class); + ctor.setAccessible(true); + this.taskCommunicator = ctor.newInstance(this); + } catch (NoSuchMethodException e) { + throw new TezUncheckedException(e); + } catch (InvocationTargetException e) { + throw new TezUncheckedException(e); + } catch (InstantiationException e) { + throw new TezUncheckedException(e); + } catch (IllegalAccessException e) { throw new TezUncheckedException(e); - } - if (LOG.isDebugEnabled()) { - LOG.debug("Not starting TaskAttemptListener RPC in LocalMode"); } } } - void refreshServiceAcls(Configuration configuration, - PolicyProvider policyProvider) { - this.server.refreshServiceAcl(configuration, policyProvider); + @Override + public void serviceStart() { + taskCommunicator.init(getConfig()); + taskCommunicator.start(); } @Override public void serviceStop() { - stopRpcServer(); - } - - protected void stopRpcServer() { - if (server != null) { - server.stop(); + if (taskCommunicator != null) { + taskCommunicator.stop(); + taskCommunicator = null; } } - public InetSocketAddress getAddress() { - return address; - } - @Override - public long getProtocolVersion(String protocol, long clientVersion) - throws IOException { - return versionID; + public ApplicationAttemptId getApplicationAttemptId() { + return context.getApplicationAttemptId(); } @Override - public ProtocolSignature getProtocolSignature(String protocol, - long clientVersion, int clientMethodsHash) throws IOException { - return ProtocolSignature.getProtocolSignature(this, protocol, - clientVersion, clientMethodsHash); + public Credentials getCredentials() { + return context.getAppCredentials(); } @Override - public ContainerTask getTask(ContainerContext containerContext) - throws IOException { + public TaskHeartbeatResponse heartbeat(TaskHeartbeatRequest request) + throws IOException, TezException { + ContainerId containerId = ConverterUtils.toContainerId(request + .getContainerIdentifier()); 
+ if (LOG.isDebugEnabled()) { + LOG.debug("Received heartbeat from container" + + ", request=" + request); + } - ContainerTask task = null; + if (!registeredContainers.containsKey(containerId)) { + LOG.warn("Received task heartbeat from unknown container with id: " + containerId + + ", asking it to die"); + return RESPONSE_SHOULD_DIE; + } - if (containerContext == null || containerContext.getContainerIdentifier() == null) { - LOG.info("Invalid task request with an empty containerContext or containerId"); - task = TASK_FOR_INVALID_JVM; - } else { - ContainerId containerId = ConverterUtils.toContainerId(containerContext - .getContainerIdentifier()); + // A heartbeat can come in anytime. The AM may have made a decision to kill a running task/container + // meanwhile. If the decision is processed through the pipeline before the heartbeat is processed, + // the heartbeat will be dropped. Otherwise the heartbeat will be processed - and the system + // knows how to handle this - via FailedInputEvents for example (relevant only if the heartbeat has events). + // So - avoiding synchronization. + + pingContainerHeartbeatHandler(containerId); + List<TezEvent> outEvents = null; + TezTaskAttemptID taskAttemptID = request.getTaskAttemptId(); + if (taskAttemptID != null) { + ContainerId containerIdFromMap = registeredAttempts.get(taskAttemptID); + if (containerIdFromMap == null || !containerIdFromMap.equals(containerId)) { + // This can happen when a task heartbeats. Meanwhile the container is unregistered. + // The information will eventually make it through to the plugin via a corresponding unregister. + // There's a race in that case between the unregister making it through, and this method returning. + // TODO TEZ-2003. An exception back is likely a better approach than sending a shouldDie = true, + // so that the plugin can handle the scenario. Alternately augment the response with error codes. + // Error codes would be better than exceptions.
+ LOG.info("Attempt: " + taskAttemptID + " is not recognized for heartbeats"); + return RESPONSE_SHOULD_DIE; + } + + List<TezEvent> inEvents = request.getEvents(); if (LOG.isDebugEnabled()) { - LOG.debug("Container with id: " + containerId + " asked for a task"); + LOG.debug("Ping from " + taskAttemptID.toString() + + " events: " + (inEvents != null ? inEvents.size() : -1)); } - if (!registeredContainers.containsKey(containerId)) { - if(context.getAllContainers().get(containerId) == null) { - LOG.info("Container with id: " + containerId - + " is invalid and will be killed"); - } else { - LOG.info("Container with id: " + containerId - + " is valid, but no longer registered, and will be killed"); - } - task = TASK_FOR_INVALID_JVM; - } else { - pingContainerHeartbeatHandler(containerId); - task = getContainerTask(containerId); - if (task == null) { - if (LOG.isDebugEnabled()) { - LOG.debug("No task current assigned to Container with id: " + containerId); - } - } else if (task == TASK_FOR_INVALID_JVM) { - LOG.info("Container with id: " + containerId - + " is valid, but no longer registered, and will be killed. 
Race condition."); + + List<TezEvent> otherEvents = new ArrayList<TezEvent>(); + for (TezEvent tezEvent : ListUtils.emptyIfNull(inEvents)) { + final EventType eventType = tezEvent.getEventType(); + if (eventType == EventType.TASK_STATUS_UPDATE_EVENT || + eventType == EventType.TASK_ATTEMPT_COMPLETED_EVENT) { + context.getEventHandler() + .handle(getTaskAttemptEventFromTezEvent(taskAttemptID, tezEvent)); } else { - context.getEventHandler().handle( - new TaskAttemptEventStartedRemotely(task.getTaskSpec() - .getTaskAttemptID(), containerId, context - .getApplicationACLs())); - LOG.info("Container with id: " + containerId + " given task: " - + task.getTaskSpec().getTaskAttemptID()); + otherEvents.add(tezEvent); } } + if(!otherEvents.isEmpty()) { + TezVertexID vertexId = taskAttemptID.getTaskID().getVertexID(); + context.getEventHandler().handle( + new VertexEventRouteEvent(vertexId, Collections.unmodifiableList(otherEvents))); + } + taskHeartbeatHandler.pinged(taskAttemptID); + outEvents = context + .getCurrentDAG() + .getVertex(taskAttemptID.getTaskID().getVertexID()) + .getTask(taskAttemptID.getTaskID()) + .getTaskAttemptTezEvents(taskAttemptID, request.getStartIndex(), + request.getMaxEvents()); } - if (LOG.isDebugEnabled()) { - LOG.debug("getTask returning task: " + task); - } - return task; + return new TaskHeartbeatResponse(false, outEvents); + } + + @Override + public boolean isKnownContainer(ContainerId containerId) { + return context.getAllContainers().get(containerId) != null; + } + + @Override + public void taskStartedRemotely(TezTaskAttemptID taskAttemptID, ContainerId containerId) { + context.getEventHandler().handle(new TaskAttemptEventStartedRemotely(taskAttemptID, containerId, null)); + pingContainerHeartbeatHandler(containerId); } /** * Child checking whether it can commit. 
- * + * <p/> * <br/> * Repeatedly polls the ApplicationMaster whether it * {@link Task#canCommit(TezTaskAttemptID)} This is * a legacy from the @@ -275,25 +264,12 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements } @Override - public void unregisterTaskAttempt(TezTaskAttemptID attemptId) { - ContainerId containerId = attemptToInfoMap.get(attemptId); - if(containerId == null) { - LOG.warn("Unregister task attempt: " + attemptId + " from unknown container"); - return; - } - ContainerInfo containerInfo = registeredContainers.get(containerId); - if(containerInfo == null) { - LOG.warn("Unregister task attempt: " + attemptId + - " from non-registered container: " + containerId); - return; - } - synchronized (containerInfo) { - containerInfo.amContainerTask = null; - attemptToInfoMap.remove(attemptId); - } - + public InetSocketAddress getAddress() { + return taskCommunicator.getAddress(); } + // The TaskAttemptListener register / unregister methods in this class are not thread safe. + // The Tez framework should not invoke these methods from multiple threads. @Override public void dagComplete(DAG dag) { // TODO TEZ-2335. Cleanup TaskHeartbeat handler structures. 
@@ -313,50 +289,82 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements @Override public void registerRunningContainer(ContainerId containerId) { if (LOG.isDebugEnabled()) { - LOG.debug("ContainerId: " + containerId - + " registered with TaskAttemptListener"); + LOG.debug("ContainerId: " + containerId + " registered with TaskAttemptListener"); } - ContainerInfo oldInfo = registeredContainers.put(containerId, new ContainerInfo()); - if(oldInfo != null) { + ContainerInfo oldInfo = registeredContainers.put(containerId, NULL_CONTAINER_INFO); + if (oldInfo != null) { throw new TezUncheckedException( "Multiple registrations for containerId: " + containerId); } + NodeId nodeId = context.getAllContainers().get(containerId).getContainer().getNodeId(); + taskCommunicator.registerRunningContainer(containerId, nodeId.getHost(), nodeId.getPort()); + } + + @Override + public void unregisterRunningContainer(ContainerId containerId) { + if (LOG.isDebugEnabled()) { + LOG.debug("Unregistering Container from TaskAttemptListener: " + containerId); + } + ContainerInfo containerInfo = registeredContainers.remove(containerId); + if (containerInfo.taskAttemptId != null) { + registeredAttempts.remove(containerInfo.taskAttemptId); + } + taskCommunicator.registerContainerEnd(containerId); } @Override public void registerTaskAttempt(AMContainerTask amContainerTask, - ContainerId containerId) { + ContainerId containerId) { ContainerInfo containerInfo = registeredContainers.get(containerId); - if(containerInfo == null) { + if (containerInfo == null) { throw new TezUncheckedException("Registering task attempt: " + amContainerTask.getTask().getTaskAttemptID() + " to unknown container: " + containerId); } - synchronized (containerInfo) { - if(containerInfo.amContainerTask != null) { - throw new TezUncheckedException("Registering task attempt: " - + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId - + " with existing assignment to: " + 
containerInfo.amContainerTask.getTask().getTaskAttemptID()); - } - containerInfo.amContainerTask = amContainerTask; - containerInfo.taskPulled = false; - - ContainerId containerIdFromMap = - attemptToInfoMap.put(amContainerTask.getTask().getTaskAttemptID(), containerId); - if(containerIdFromMap != null) { - throw new TezUncheckedException("Registering task attempt: " - + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId - + " when already assigned to: " + containerIdFromMap); - } + if (containerInfo.taskAttemptId != null) { + throw new TezUncheckedException("Registering task attempt: " + + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId + + " with existing assignment to: " + + containerInfo.taskAttemptId); } + + if (containerInfo.taskAttemptId != null) { + throw new TezUncheckedException("Registering task attempt: " + + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId + + " with existing assignment to: " + + containerInfo.taskAttemptId); + } + + // Explicitly putting in a new entry so that synchronization is not required on the existing element in the map. 
+ registeredContainers.put(containerId, new ContainerInfo(amContainerTask.getTask().getTaskAttemptID())); + + ContainerId containerIdFromMap = registeredAttempts.put( + amContainerTask.getTask().getTaskAttemptID(), containerId); + if (containerIdFromMap != null) { + throw new TezUncheckedException("Registering task attempt: " + + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId + + " when already assigned to: " + containerIdFromMap); + } + taskCommunicator.registerRunningTaskAttempt(containerId, amContainerTask.getTask(), + amContainerTask.getAdditionalResources(), amContainerTask.getCredentials(), + amContainerTask.haveCredentialsChanged()); } @Override - public void unregisterRunningContainer(ContainerId containerId) { - if (LOG.isDebugEnabled()) { - LOG.debug("Unregistering Container from TaskAttemptListener: " - + containerId); + public void unregisterTaskAttempt(TezTaskAttemptID attemptId) { + ContainerId containerId = registeredAttempts.remove(attemptId); + if (containerId == null) { + LOG.warn("Unregister task attempt: " + attemptId + " from unknown container"); + return; } - registeredContainers.remove(containerId); + ContainerInfo containerInfo = registeredContainers.get(containerId); + if (containerInfo == null) { + LOG.warn("Unregister task attempt: " + attemptId + + " from non-registered container: " + containerId); + return; + } + // Explicitly putting in a new entry so that synchronization is not required on the existing element in the map. 
+ registeredContainers.put(containerId, NULL_CONTAINER_INFO); + taskCommunicator.unregisterRunningTaskAttempt(attemptId); } private void pingContainerHeartbeatHandler(ContainerId containerId) { @@ -364,7 +372,7 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements } private void pingContainerHeartbeatHandler(TezTaskAttemptID taskAttemptId) { - ContainerId containerId = attemptToInfoMap.get(taskAttemptId); + ContainerId containerId = registeredAttempts.get(taskAttemptId); if (containerId != null) { containerHeartbeatHandler.pinged(containerId); } else { @@ -373,146 +381,8 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements } } - @Override - public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) - throws IOException, TezException { - ContainerId containerId = ConverterUtils.toContainerId(request - .getContainerIdentifier()); - long requestId = request.getRequestId(); - if (LOG.isDebugEnabled()) { - LOG.debug("Received heartbeat from container" - + ", request=" + request); - } - - ContainerInfo containerInfo = registeredContainers.get(containerId); - if(containerInfo == null) { - LOG.warn("Received task heartbeat from unknown container with id: " + containerId + - ", asking it to die"); - TezHeartbeatResponse response = new TezHeartbeatResponse(); - response.setLastRequestId(requestId); - response.setShouldDie(); - return response; - } - - synchronized (containerInfo) { - pingContainerHeartbeatHandler(containerId); - - if(containerInfo.lastRequestId == requestId) { - LOG.warn("Old sequenceId received: " + requestId - + ", Re-sending last response to client"); - return containerInfo.lastReponse; - } - - TezHeartbeatResponse response = new TezHeartbeatResponse(); - response.setLastRequestId(requestId); - - TezTaskAttemptID taskAttemptID = request.getCurrentTaskAttemptID(); - if (taskAttemptID != null) { - ContainerId containerIdFromMap = attemptToInfoMap.get(taskAttemptID); - if(containerIdFromMap == null 
|| !containerIdFromMap.equals(containerId)) { - throw new TezException("Attempt " + taskAttemptID - + " is not recognized for heartbeat"); - } - - if(containerInfo.lastRequestId+1 != requestId) { - throw new TezException("Container " + containerId - + " has invalid request id. Expected: " - + containerInfo.lastRequestId+1 - + " and actual: " + requestId); - } - - List<TezEvent> inEvents = request.getEvents(); - if (LOG.isDebugEnabled()) { - LOG.debug("Ping from " + taskAttemptID.toString() + - " events: " + (inEvents != null? inEvents.size() : -1)); - } - - long currTime = context.getClock().getTime(); - List<TezEvent> otherEvents = new ArrayList<TezEvent>(); - // route TASK_STATUS_UPDATE_EVENT directly to TaskAttempt and route other events - // (DATA_MOVEMENT_EVENT, TASK_ATTEMPT_COMPLETED_EVENT, TASK_ATTEMPT_FAILED_EVENT) - // to VertexImpl to ensure the events ordering - // 1. DataMovementEvent is logged as RecoveryEvent before TaskAttemptFinishedEvent - // 2. TaskStatusEvent is handled before TaskAttemptFinishedEvent - for (TezEvent tezEvent : ListUtils.emptyIfNull(inEvents)) { - // for now, set the event time on the AM when it is received. - // this avoids any time disparity between machines. 
- tezEvent.setEventReceivedTime(currTime); - final EventType eventType = tezEvent.getEventType(); - if (eventType == EventType.TASK_STATUS_UPDATE_EVENT) { - TaskAttemptEvent taskAttemptEvent = new TaskAttemptEventStatusUpdate(taskAttemptID, - (TaskStatusUpdateEvent) tezEvent.getEvent()); - context.getEventHandler().handle(taskAttemptEvent); - } else { - otherEvents.add(tezEvent); - } - } - if(!otherEvents.isEmpty()) { - TezVertexID vertexId = taskAttemptID.getTaskID().getVertexID(); - context.getEventHandler().handle( - new VertexEventRouteEvent(vertexId, Collections.unmodifiableList(otherEvents))); - } - taskHeartbeatHandler.pinged(taskAttemptID); - TaskAttemptEventInfo eventInfo = context - .getCurrentDAG() - .getVertex(taskAttemptID.getTaskID().getVertexID()) - .getTaskAttemptTezEvents(taskAttemptID, request.getStartIndex(), - request.getPreRoutedStartIndex(), request.getMaxEvents()); - response.setEvents(eventInfo.getEvents()); - response.setNextFromEventId(eventInfo.getNextFromEventId()); - response.setNextPreRoutedEventId(eventInfo.getNextPreRoutedFromEventId()); - } - containerInfo.lastRequestId = requestId; - containerInfo.lastReponse = response; - return response; - } - } - private Map<String, TezLocalResource> convertLocalResourceMap(Map<String, LocalResource> ylrs) - throws IOException { - Map<String, TezLocalResource> tlrs = Maps.newHashMap(); - if (ylrs != null) { - for (Entry<String, LocalResource> ylrEntry : ylrs.entrySet()) { - TezLocalResource tlr; - try { - tlr = TezConverterUtils.convertYarnLocalResourceToTez(ylrEntry.getValue()); - } catch (URISyntaxException e) { - throw new IOException(e); - } - tlrs.put(ylrEntry.getKey(), tlr); - } - } - return tlrs; - } - - private ContainerTask getContainerTask(ContainerId containerId) throws IOException { - ContainerTask containerTask = null; - ContainerInfo containerInfo = registeredContainers.get(containerId); - if (containerInfo == null) { - // This can happen if an unregisterTask comes in after we've 
done the initial checks for - // registered containers. (Race between getTask from the container, and a potential STOP_CONTAINER - // from somewhere within the AM) - // Implies that an un-registration has taken place and the container needs to be asked to die. - LOG.info("Container with id: " + containerId - + " is valid, but no longer registered, and will be killed"); - containerTask = TASK_FOR_INVALID_JVM; - } else { - synchronized (containerInfo) { - if (containerInfo.amContainerTask != null) { - if (!containerInfo.taskPulled) { - containerInfo.taskPulled = true; - AMContainerTask amContainerTask = containerInfo.amContainerTask; - containerTask = new ContainerTask(amContainerTask.getTask(), false, - convertLocalResourceMap(amContainerTask.getAdditionalResources()), - amContainerTask.getCredentials(), amContainerTask.haveCredentialsChanged()); - } else { - containerTask = null; - } - } else { - containerTask = null; - } - } - } - return containerTask; + public TaskCommunicator getTaskCommunicator() { + return taskCommunicator; } } http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java new file mode 100644 index 0000000..e40f79c --- /dev/null +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java @@ -0,0 +1,476 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.dag.app; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.CommonConfigurationKeysPublic; +import org.apache.hadoop.ipc.ProtocolSignature; +import org.apache.hadoop.ipc.RPC; +import org.apache.hadoop.ipc.Server; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.authorize.PolicyProvider; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.tez.common.*; +import org.apache.tez.common.ContainerContext; +import org.apache.tez.common.security.JobTokenIdentifier; +import org.apache.tez.common.security.JobTokenSecretManager; +import org.apache.tez.common.security.TokenCache; +import org.apache.tez.dag.api.TaskCommunicator; +import org.apache.tez.dag.api.TaskCommunicatorContext; +import org.apache.tez.dag.api.TaskHeartbeatRequest; +import 
org.apache.tez.dag.api.TaskHeartbeatResponse; +import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.api.TezException; +import org.apache.tez.dag.api.TezUncheckedException; +import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider; +import org.apache.tez.dag.records.TezTaskAttemptID; +import org.apache.tez.runtime.api.impl.TaskSpec; +import org.apache.tez.runtime.api.impl.TezHeartbeatRequest; +import org.apache.tez.runtime.api.impl.TezHeartbeatResponse; + +@InterfaceAudience.Private +public class TezTaskCommunicatorImpl extends TaskCommunicator { + + private static final Log LOG = LogFactory.getLog(TezTaskCommunicatorImpl.class); + + private static final ContainerTask TASK_FOR_INVALID_JVM = new ContainerTask( + null, true, null, null, false); + + private final TaskCommunicatorContext taskCommunicatorContext; + + private final ConcurrentMap<ContainerId, ContainerInfo> registeredContainers = + new ConcurrentHashMap<ContainerId, ContainerInfo>(); + private final ConcurrentMap<TaskAttempt, ContainerId> attemptToContainerMap = + new ConcurrentHashMap<TaskAttempt, ContainerId>(); + + private final TezTaskUmbilicalProtocol taskUmbilical; + private InetSocketAddress address; + private Server server; + + private static final class ContainerInfo { + + ContainerInfo(ContainerId containerId) { + this.containerId = containerId; + } + + ContainerId containerId; + TezHeartbeatResponse lastResponse = null; + TaskSpec taskSpec = null; + long lastRequestId = 0; + Map<String, LocalResource> additionalLRs = null; + Credentials credentials = null; + boolean credentialsChanged = false; + boolean taskPulled = false; + + void reset() { + taskSpec = null; + additionalLRs = null; + credentials = null; + credentialsChanged = false; + taskPulled = false; + } + } + + + + /** + * Construct the service. 
+ */ + public TezTaskCommunicatorImpl(TaskCommunicatorContext taskCommunicatorContext) { + super(TezTaskCommunicatorImpl.class.getName()); + this.taskCommunicatorContext = taskCommunicatorContext; + this.taskUmbilical = new TezTaskUmbilicalProtocolImpl(); + } + + + @Override + public void serviceStart() { + + startRpcServer(); + } + + @Override + public void serviceStop() { + stopRpcServer(); + } + + protected void startRpcServer() { + Configuration conf = getConfig(); + if (!conf.getBoolean(TezConfiguration.TEZ_LOCAL_MODE, TezConfiguration.TEZ_LOCAL_MODE_DEFAULT)) { + try { + JobTokenSecretManager jobTokenSecretManager = + new JobTokenSecretManager(); + Token<JobTokenIdentifier> sessionToken = TokenCache.getSessionToken(taskCommunicatorContext.getCredentials()); + jobTokenSecretManager.addTokenForJob( + taskCommunicatorContext.getApplicationAttemptId().getApplicationId().toString(), sessionToken); + + server = new RPC.Builder(conf) + .setProtocol(TezTaskUmbilicalProtocol.class) + .setBindAddress("0.0.0.0") + .setPort(0) + .setInstance(taskUmbilical) + .setNumHandlers( + conf.getInt(TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT, + TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT)) + .setPortRangeConfig(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE) + .setSecretManager(jobTokenSecretManager).build(); + + // Enable service authorization? 
+ if (conf.getBoolean( + CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION, + false)) { + refreshServiceAcls(conf, new TezAMPolicyProvider()); + } + + server.start(); + this.address = NetUtils.getConnectAddress(server); + LOG.info("Instantiated TezTaskCommunicator RPC at " + this.address); + } catch (IOException e) { + throw new TezUncheckedException(e); + } + } else { + try { + this.address = new InetSocketAddress(InetAddress.getLocalHost(), 0); + } catch (UnknownHostException e) { + throw new TezUncheckedException(e); + } + if (LOG.isDebugEnabled()) { + LOG.debug("Not starting TaskAttemptListener RPC in LocalMode"); + } + } + } + + protected void stopRpcServer() { + if (server != null) { + server.stop(); + server = null; + } + } + + private void refreshServiceAcls(Configuration configuration, + PolicyProvider policyProvider) { + this.server.refreshServiceAcl(configuration, policyProvider); + } + + @Override + public void registerRunningContainer(ContainerId containerId, String host, int port) { + ContainerInfo oldInfo = registeredContainers.putIfAbsent(containerId, new ContainerInfo(containerId)); + if (oldInfo != null) { + throw new TezUncheckedException("Multiple registrations for containerId: " + containerId); + } + } + + @Override + public void registerContainerEnd(ContainerId containerId) { + ContainerInfo containerInfo = registeredContainers.remove(containerId); + if (containerInfo != null) { + synchronized(containerInfo) { + if (containerInfo.taskSpec != null && containerInfo.taskSpec.getTaskAttemptID() != null) { + attemptToContainerMap.remove(containerInfo.taskSpec.getTaskAttemptID()); + } + } + } + } + + @Override + public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec, + Map<String, LocalResource> additionalResources, + Credentials credentials, boolean credentialsChanged) { + + ContainerInfo containerInfo = registeredContainers.get(containerId); + Preconditions.checkNotNull(containerInfo, + "Cannot register task 
attempt: " + taskSpec.getTaskAttemptID() + " to unknown container: " + + containerId); + synchronized (containerInfo) { + if (containerInfo.taskSpec != null) { + throw new TezUncheckedException( + "Cannot register task: " + taskSpec.getTaskAttemptID() + " to container: " + + containerId + " , with pre-existing assignment: " + + containerInfo.taskSpec.getTaskAttemptID()); + } + containerInfo.taskSpec = taskSpec; + containerInfo.additionalLRs = additionalResources; + containerInfo.credentials = credentials; + containerInfo.credentialsChanged = credentialsChanged; + containerInfo.taskPulled = false; + + ContainerId oldId = attemptToContainerMap.putIfAbsent(new TaskAttempt(taskSpec.getTaskAttemptID()), containerId); + if (oldId != null) { + throw new TezUncheckedException( + "Attempting to register an already registered taskAttempt with id: " + + taskSpec.getTaskAttemptID() + " to containerId: " + containerId + + ". Already registered to containerId: " + oldId); + } + } + + } + + @Override + public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID) { + TaskAttempt taskAttempt = new TaskAttempt(taskAttemptID); + ContainerId containerId = attemptToContainerMap.remove(taskAttempt); + if(containerId == null) { + LOG.warn("Unregister task attempt: " + taskAttempt + " from unknown container"); + return; + } + ContainerInfo containerInfo = registeredContainers.get(containerId); + if (containerInfo == null) { + LOG.warn("Unregister task attempt: " + taskAttempt + + " from non-registered container: " + containerId); + return; + } + synchronized (containerInfo) { + containerInfo.reset(); + attemptToContainerMap.remove(taskAttempt); + } + } + + @Override + public InetSocketAddress getAddress() { + return address; + } + + public TezTaskUmbilicalProtocol getUmbilical() { + return this.taskUmbilical; + } + + private class TezTaskUmbilicalProtocolImpl implements TezTaskUmbilicalProtocol { + + @Override + public ContainerTask getTask(ContainerContext containerContext) 
throws IOException { + ContainerTask task = null; + if (containerContext == null || containerContext.getContainerIdentifier() == null) { + LOG.info("Invalid task request with an empty containerContext or containerId"); + task = TASK_FOR_INVALID_JVM; + } else { + ContainerId containerId = ConverterUtils.toContainerId(containerContext + .getContainerIdentifier()); + if (LOG.isDebugEnabled()) { + LOG.debug("Container with id: " + containerId + " asked for a task"); + } + task = getContainerTask(containerId); + if (task != null && !task.shouldDie()) { + taskCommunicatorContext + .taskStartedRemotely(task.getTaskSpec().getTaskAttemptID(), containerId); + } + } + if (LOG.isDebugEnabled()) { + LOG.debug("getTask returning task: " + task); + } + return task; + } + + @Override + public boolean canCommit(TezTaskAttemptID taskAttemptId) throws IOException { + return taskCommunicatorContext.canCommit(taskAttemptId); + } + + @Override + public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOException, + TezException { + ContainerId containerId = ConverterUtils.toContainerId(request.getContainerIdentifier()); + long requestId = request.getRequestId(); + if (LOG.isDebugEnabled()) { + LOG.debug("Received heartbeat from container" + + ", request=" + request); + } + + ContainerInfo containerInfo = registeredContainers.get(containerId); + if (containerInfo == null) { + LOG.warn("Received task heartbeat from unknown container with id: " + containerId + + ", asking it to die"); + TezHeartbeatResponse response = new TezHeartbeatResponse(); + response.setLastRequestId(requestId); + response.setShouldDie(); + return response; + } + + synchronized (containerInfo) { + if (containerInfo.lastRequestId == requestId) { + LOG.warn("Old sequenceId received: " + requestId + + ", Re-sending last response to client"); + return containerInfo.lastResponse; + } + } + + TaskHeartbeatResponse tResponse = null; + + + TezTaskAttemptID taskAttemptID = request.getCurrentTaskAttemptID(); 
+ if (taskAttemptID != null) { + synchronized (containerInfo) { + ContainerId containerIdFromMap = attemptToContainerMap.get(new TaskAttempt(taskAttemptID)); + if (containerIdFromMap == null || !containerIdFromMap.equals(containerId)) { + throw new TezException("Attempt " + taskAttemptID + + " is not recognized for heartbeat"); + } + + if (containerInfo.lastRequestId + 1 != requestId) { + throw new TezException("Container " + containerId + + " has invalid request id. Expected: " + + containerInfo.lastRequestId + 1 + + " and actual: " + requestId); + } + } + TaskHeartbeatRequest tRequest = new TaskHeartbeatRequest(request.getContainerIdentifier(), + request.getCurrentTaskAttemptID(), request.getEvents(), request.getStartIndex(), + request.getMaxEvents()); + tResponse = taskCommunicatorContext.heartbeat(tRequest); + } + TezHeartbeatResponse response; + if (tResponse == null) { + response = new TezHeartbeatResponse(); + } else { + response = new TezHeartbeatResponse(tResponse.getEvents()); + } + response.setLastRequestId(requestId); + containerInfo.lastRequestId = requestId; + containerInfo.lastResponse = response; + return response; + } + + + // TODO Remove this method once we move to the Protobuf RPC engine + @Override + public long getProtocolVersion(String protocol, long clientVersion) throws IOException { + return versionID; + } + + // TODO Remove this method once we move to the Protobuf RPC engine + @Override + public ProtocolSignature getProtocolSignature(String protocol, long clientVersion, + int clientMethodsHash) throws IOException { + return ProtocolSignature.getProtocolSignature(this, protocol, + clientVersion, clientMethodsHash); + } + } + + private ContainerTask getContainerTask(ContainerId containerId) throws IOException { + ContainerInfo containerInfo = registeredContainers.get(containerId); + ContainerTask task = null; + if (containerInfo == null) { + if (taskCommunicatorContext.isKnownContainer(containerId)) { + LOG.info("Container with id: " + 
containerId + + " is valid, but no longer registered, and will be killed"); + } else { + LOG.info("Container with id: " + containerId + + " is invalid and will be killed"); + } + task = TASK_FOR_INVALID_JVM; + } else { + synchronized (containerInfo) { + if (containerInfo.taskSpec != null) { + if (!containerInfo.taskPulled) { + containerInfo.taskPulled = true; + task = constructContainerTask(containerInfo); + } else { + if (LOG.isDebugEnabled()) { + LOG.debug("Task " + containerInfo.taskSpec.getTaskAttemptID() + + " already sent to container: " + containerId); + } + task = null; + } + } else { + task = null; + if (LOG.isDebugEnabled()) { + LOG.debug("No task assigned yet for running container: " + containerId); + } + } + } + } + return task; + } + + private ContainerTask constructContainerTask(ContainerInfo containerInfo) throws IOException { + return new ContainerTask(containerInfo.taskSpec, false, + convertLocalResourceMap(containerInfo.additionalLRs), containerInfo.credentials, + containerInfo.credentialsChanged); + } + + private Map<String, TezLocalResource> convertLocalResourceMap(Map<String, LocalResource> ylrs) + throws IOException { + Map<String, TezLocalResource> tlrs = Maps.newHashMap(); + if (ylrs != null) { + for (Map.Entry<String, LocalResource> ylrEntry : ylrs.entrySet()) { + TezLocalResource tlr; + try { + tlr = TezConverterUtils.convertYarnLocalResourceToTez(ylrEntry.getValue()); + } catch (URISyntaxException e) { + throw new IOException(e); + } + tlrs.put(ylrEntry.getKey(), tlr); + } + } + return tlrs; + } + + + // Holder for Task information, which eventually will likely be VertexImplm taskIndex, attemptIndex + private static class TaskAttempt { + // TODO TEZ-2003 Change this to work with VertexName, int id, int version + // TODO TEZ-2003 Avoid constructing this unit all over the place + private TezTaskAttemptID taskAttemptId; + + TaskAttempt(TezTaskAttemptID taskAttemptId) { + this.taskAttemptId = taskAttemptId; + } + + @Override + public boolean 
equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TaskAttempt)) { + return false; + } + + TaskAttempt that = (TaskAttempt) o; + + if (!taskAttemptId.equals(that.taskAttemptId)) { + return false; + } + + return true; + } + + @Override + public int hashCode() { + return taskAttemptId.hashCode(); + } + + @Override + public String toString() { + return "TaskAttempt{" + "taskAttemptId=" + taskAttemptId + '}'; + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java index a5cab86..d9d668f 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java @@ -59,6 +59,8 @@ import org.apache.tez.dag.api.TezException; import org.apache.tez.dag.api.TezUncheckedException; import org.apache.tez.dag.app.AppContext; import org.apache.tez.dag.app.TaskAttemptListener; +import org.apache.tez.dag.app.TaskAttemptListenerImpTezDag; +import org.apache.tez.dag.app.TezTaskCommunicatorImpl; import org.apache.tez.dag.app.rm.NMCommunicatorEvent; import org.apache.tez.dag.app.rm.NMCommunicatorLaunchRequestEvent; import org.apache.tez.dag.app.rm.NMCommunicatorStopRequestEvent; @@ -86,7 +88,7 @@ public class LocalContainerLauncher extends AbstractService implements private static final Logger LOG = LoggerFactory.getLogger(LocalContainerLauncher.class); private final AppContext context; - private final TaskAttemptListener taskAttemptListener; + private final TezTaskUmbilicalProtocol taskUmbilicalProtocol; private final AtomicBoolean serviceStopped = new AtomicBoolean(false); private final 
String workingDirectory; private final Map<String, String> localEnv = new HashMap<String, String>(); @@ -114,7 +116,9 @@ public class LocalContainerLauncher extends AbstractService implements String workingDirectory) throws UnknownHostException { super(LocalContainerLauncher.class.getName()); this.context = context; - this.taskAttemptListener = taskAttemptListener; + TaskAttemptListenerImpTezDag taListener = (TaskAttemptListenerImpTezDag)taskAttemptListener; + TezTaskCommunicatorImpl taskComm = (TezTaskCommunicatorImpl) taListener.getTaskCommunicator(); + this.taskUmbilicalProtocol = taskComm.getUmbilical(); this.workingDirectory = workingDirectory; AuxiliaryServiceHelper.setServiceDataIntoEnv( ShuffleUtils.SHUFFLE_HANDLER_SERVICE_ID, ByteBuffer.allocate(4).putInt(0), localEnv); @@ -219,7 +223,7 @@ public class LocalContainerLauncher extends AbstractService implements tezChild = createTezChild(context.getAMConf(), event.getContainerId(), tokenIdentifier, context.getApplicationAttemptId().getAttemptId(), context.getLocalDirs(), - (TezTaskUmbilicalProtocol) taskAttemptListener, + taskUmbilicalProtocol, TezCommonUtils.parseCredentialsBytes(event.getContainerLaunchContext().getTokens().array())); } catch (InterruptedException e) { handleLaunchFailed(e, event.getContainerId()); http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java index a6b403d..0fc2e12 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java @@ -22,6 +22,7 @@ import java.util.List; import org.apache.hadoop.yarn.api.records.Container; import 
org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.NodeId; import org.apache.hadoop.yarn.event.EventHandler; import org.apache.tez.dag.records.TezTaskAttemptID; @@ -32,5 +33,5 @@ public interface AMContainer extends EventHandler<AMContainerEvent>{ public Container getContainer(); public List<TezTaskAttemptID> getAllTaskAttempts(); public TezTaskAttemptID getCurrentTaskAttempt(); - + } http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java index 682cd02..0398882 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java @@ -27,6 +27,8 @@ import org.apache.tez.runtime.api.impl.TaskSpec; public class AMContainerEventAssignTA extends AMContainerEvent { + // TODO TEZ-2003. Add the task priority to this event. + private final TezTaskAttemptID attemptId; // TODO Maybe have tht TAL pull the remoteTask from the TaskAttempt itself ? 
private final TaskSpec remoteTaskSpec; http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java index 330f2b7..1acec9c 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java @@ -35,6 +35,7 @@ import org.apache.hadoop.yarn.api.records.ContainerExitStatus; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.NodeId; import org.apache.hadoop.yarn.event.Event; import org.apache.hadoop.yarn.event.EventHandler; import org.apache.hadoop.yarn.state.InvalidStateTransitonException; http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java index 8fa57d3..24f3019 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java @@ -37,6 +37,7 @@ import java.util.concurrent.atomic.AtomicLong; import org.apache.tez.dag.app.dag.DAG; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.security.Credentials; import org.apache.hadoop.service.AbstractService; @@ -50,7 +51,10 @@ import 
org.apache.tez.client.TezApiVersionInfo; import org.apache.tez.common.ContainerContext; import org.apache.tez.common.ContainerTask; import org.apache.tez.common.counters.TezCounters; +import org.apache.tez.dag.api.TaskHeartbeatRequest; +import org.apache.tez.dag.api.TaskHeartbeatResponse; import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.api.TaskCommunicator; import org.apache.tez.dag.api.TezUncheckedException; import org.apache.tez.dag.app.launcher.ContainerLauncher; import org.apache.tez.dag.app.rm.NMCommunicatorEvent; @@ -72,8 +76,6 @@ import org.apache.tez.runtime.api.impl.TaskSpec; import org.apache.tez.runtime.api.impl.TaskStatistics; import org.apache.tez.runtime.api.impl.TezEvent; import org.apache.tez.runtime.api.impl.EventMetaData.EventProducerConsumerType; -import org.apache.tez.runtime.api.impl.TezHeartbeatRequest; -import org.apache.tez.runtime.api.impl.TezHeartbeatResponse; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -130,6 +132,7 @@ public class MockDAGAppMaster extends DAGAppMaster { Map<ContainerId, ContainerData> containers = Maps.newConcurrentMap(); ArrayBlockingQueue<Worker> workers; TaskAttemptListenerImpTezDag taListener; + TezTaskCommunicatorImpl taskCommunicator; AtomicBoolean startScheduling = new AtomicBoolean(true); AtomicBoolean goFlag; @@ -194,6 +197,7 @@ public class MockDAGAppMaster extends DAGAppMaster { @Override public void serviceStart() throws Exception { taListener = (TaskAttemptListenerImpTezDag) getTaskAttemptListener(); + taskCommunicator = (TezTaskCommunicatorImpl) taListener.getTaskCommunicator(); eventHandlingThread = new Thread(this); eventHandlingThread.start(); ExecutorService rawExecutor = Executors.newFixedThreadPool(handlerConcurrency, @@ -333,10 +337,10 @@ public class MockDAGAppMaster extends DAGAppMaster { } } - private void doHeartbeat(TezHeartbeatRequest request, ContainerData cData) throws Exception { + private void 
doHeartbeat(TaskHeartbeatRequest request, ContainerData cData) throws Exception { long startTime = System.nanoTime(); long startCpuTime = threadMxBean.getCurrentThreadCpuTime(); - TezHeartbeatResponse response = taListener.heartbeat(request); + TaskHeartbeatResponse response = taListener.heartbeat(request); if (response.shouldDie()) { cData.remove(); } else { @@ -388,7 +392,8 @@ public class MockDAGAppMaster extends DAGAppMaster { try { if (cData.taId == null) { // if container is not assigned a task, ask for a task - ContainerTask cTask = taListener.getTask(new ContainerContext(cData.cIdStr)); + ContainerTask cTask = + taskCommunicator.getUmbilical().getTask(new ContainerContext(cData.cIdStr)); if (cTask != null) { if (cTask.shouldDie()) { cData.remove(); @@ -424,8 +429,11 @@ public class MockDAGAppMaster extends DAGAppMaster { events.add(new TezEvent(new TaskStatusUpdateEvent(counters, progress, stats), new EventMetaData( EventProducerConsumerType.SYSTEM, cData.vName, "", cData.taId), getContext().getClock().getTime())); - TezHeartbeatRequest request = new TezHeartbeatRequest(cData.numUpdates, events, - cData.nextPreRoutedFromEventId, cData.cIdStr, cData.taId, cData.nextFromEventId, 50000); +// TezHeartbeatRequest request = new TezHeartbeatRequest(cData.numUpdates, events, +// cData.cIdStr, cData.taId, cData.nextFromEventId, 50000); + TaskHeartbeatRequest request = + new TaskHeartbeatRequest(cData.cIdStr, cData.taId, events, cData.nextFromEventId, cData.nextPreRoutedFromEventId, + 50000); doHeartbeat(request, cData); } else if (version != null && cData.taId.getId() <= version.intValue()) { preemptContainer(cData); @@ -436,8 +444,9 @@ public class MockDAGAppMaster extends DAGAppMaster { new TaskAttemptCompletedEvent(), new EventMetaData( EventProducerConsumerType.SYSTEM, cData.vName, "", cData.taId), getContext().getClock().getTime())); - TezHeartbeatRequest request = new TezHeartbeatRequest(++cData.numUpdates, events, - cData.nextPreRoutedFromEventId, 
cData.cIdStr, cData.taId, cData.nextFromEventId, 10000); + TaskHeartbeatRequest request = + new TaskHeartbeatRequest(cData.cIdStr, cData.taId, events, cData.nextFromEventId, cData.nextPreRoutedFromEventId, + 10000); doHeartbeat(request, cData); cData.clear(); } http://git-wip-us.apache.org/repos/asf/tez/blob/7ab75d82/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java index d8a7388..c454c7c 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java @@ -1,16 +1,16 @@ /* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ package org.apache.tez.dag.app; @@ -19,6 +19,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -37,6 +38,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ApplicationAccessType; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.Container; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.event.Event; import org.apache.hadoop.yarn.event.EventHandler; @@ -45,6 +47,12 @@ import org.apache.tez.common.ContainerTask; import org.apache.tez.common.security.JobTokenSecretManager; import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.dag.api.TezException; +import org.apache.hadoop.yarn.api.records.NodeId; +import org.apache.hadoop.yarn.event.EventHandler; +import org.apache.tez.common.ContainerContext; +import org.apache.tez.common.ContainerTask; +import org.apache.tez.common.TezTaskUmbilicalProtocol; +import org.apache.tez.dag.api.TaskCommunicatorContext; import org.apache.tez.dag.app.dag.DAG; import org.apache.tez.dag.app.dag.Vertex; import org.apache.tez.dag.app.dag.event.TaskAttemptEventType; @@ -108,8 +116,18 @@ public class TestTaskAttemptListenerImplTezDag { doReturn(amContainerMap).when(appContext).getAllContainers(); doReturn(clock).when(appContext).getClock(); - taskAttemptListener = new TaskAttemptListenerImplForTest(appContext, - mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), null); + NodeId nodeId = NodeId.newInstance("localhost", 0); + 
AMContainer amContainer = mock(AMContainer.class); + Container container = mock(Container.class); + doReturn(nodeId).when(container).getNodeId(); + doReturn(amContainer).when(amContainerMap).get(any(ContainerId.class)); + doReturn(container).when(amContainer).getContainer(); + + taskAttemptListener = + new TaskAttemptListenerImpTezDag(appContext, mock(TaskHeartbeatHandler.class), + mock(ContainerHeartbeatHandler.class), null); + TezTaskCommunicatorImpl taskCommunicator = (TezTaskCommunicatorImpl)taskAttemptListener.getTaskCommunicator(); + TezTaskUmbilicalProtocol tezUmbilical = taskCommunicator.getUmbilical(); taskSpec = mock(TaskSpec.class); doReturn(taskAttemptID).when(taskSpec).getTaskAttemptID(); @@ -121,32 +139,30 @@ public class TestTaskAttemptListenerImplTezDag { public void testGetTask() throws IOException { ContainerId containerId1 = createContainerId(appId, 1); - doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId1); ContainerContext containerContext1 = new ContainerContext(containerId1.toString()); - containerTask = taskAttemptListener.getTask(containerContext1); + containerTask = tezUmbilical.getTask(containerContext1); assertTrue(containerTask.shouldDie()); ContainerId containerId2 = createContainerId(appId, 2); - doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId2); ContainerContext containerContext2 = new ContainerContext(containerId2.toString()); taskAttemptListener.registerRunningContainer(containerId2); - containerTask = taskAttemptListener.getTask(containerContext2); + containerTask = tezUmbilical.getTask(containerContext2); assertNull(containerTask); // Valid task registered taskAttemptListener.registerTaskAttempt(amContainerTask, containerId2); - containerTask = taskAttemptListener.getTask(containerContext2); + containerTask = tezUmbilical.getTask(containerContext2); assertFalse(containerTask.shouldDie()); assertEquals(taskSpec, containerTask.getTaskSpec()); // Task unregistered. 
Should respond to heartbeats - taskAttemptListener.unregisterTaskAttempt(taskAttemptID); - containerTask = taskAttemptListener.getTask(containerContext2); + taskAttemptListener.unregisterTaskAttempt(taskAttemptId); + containerTask = tezUmbilical.getTask(containerContext2); assertNull(containerTask); // Container unregistered. Should send a shouldDie = true taskAttemptListener.unregisterRunningContainer(containerId2); - containerTask = taskAttemptListener.getTask(containerContext2); + containerTask = tezUmbilical.getTask(containerContext2); assertTrue(containerTask.shouldDie()); ContainerId containerId3 = createContainerId(appId, 3); @@ -160,27 +176,30 @@ public class TestTaskAttemptListenerImplTezDag { AMContainerTask amContainerTask2 = new AMContainerTask(taskSpec, null, null, false, 0); taskAttemptListener.registerTaskAttempt(amContainerTask2, containerId3); taskAttemptListener.unregisterRunningContainer(containerId3); - containerTask = taskAttemptListener.getTask(containerContext3); + containerTask = tezUmbilical.getTask(containerContext3); assertTrue(containerTask.shouldDie()); } @Test(timeout = 5000) public void testGetTaskMultiplePulls() throws IOException { + TezTaskCommunicatorImpl taskCommunicator = (TezTaskCommunicatorImpl)taskAttemptListener.getTaskCommunicator(); + TezTaskUmbilicalProtocol tezUmbilical = taskCommunicator.getUmbilical(); + ContainerId containerId1 = createContainerId(appId, 1); doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId1); ContainerContext containerContext1 = new ContainerContext(containerId1.toString()); taskAttemptListener.registerRunningContainer(containerId1); - containerTask = taskAttemptListener.getTask(containerContext1); + containerTask = tezUmbilical.getTask(containerContext1); assertNull(containerTask); // Register task taskAttemptListener.registerTaskAttempt(amContainerTask, containerId1); - containerTask = taskAttemptListener.getTask(containerContext1); + containerTask = 
tezUmbilical.getTask(containerContext1); assertFalse(containerTask.shouldDie()); assertEquals(taskSpec, containerTask.getTaskSpec()); // Try pulling again - simulates re-use pull - containerTask = taskAttemptListener.getTask(containerContext1); + containerTask = tezUmbilical.getTask(containerContext1); assertNull(containerTask); } @@ -325,13 +344,11 @@ public class TestTaskAttemptListenerImplTezDag { return ContainerId.newInstance(appAttemptId, containerIdx); } - private static class TaskAttemptListenerImplForTest extends TaskAttemptListenerImpTezDag { + private static class TezTaskCommunicatorImplForTest extends TezTaskCommunicatorImpl { - public TaskAttemptListenerImplForTest(AppContext context, - TaskHeartbeatHandler thh, - ContainerHeartbeatHandler chh, - JobTokenSecretManager jobTokenSecretManager) { - super(context, thh, chh, jobTokenSecretManager); + public TezTaskCommunicatorImplForTest( + TaskCommunicatorContext taskCommunicatorContext) { + super(taskCommunicatorContext); } @Override
