http://git-wip-us.apache.org/repos/asf/tez/blob/36e7f854/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java ---------------------------------------------------------------------- diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java new file mode 100644 index 0000000..4a6ce33 --- /dev/null +++ b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java @@ -0,0 +1,512 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.service.impl; + +import java.io.File; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.security.PrivilegedExceptionAction; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import com.google.common.base.Preconditions; +import com.google.common.base.Stopwatch; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DataInputBuffer; +import org.apache.hadoop.ipc.RPC; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.SecurityUtil; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.service.AbstractService; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.util.AuxiliaryServiceHelper; +import org.apache.log4j.Logger; +import org.apache.tez.common.TezCommonUtils; +import org.apache.tez.common.TezTaskUmbilicalProtocol; +import org.apache.tez.common.security.JobTokenIdentifier; +import org.apache.tez.common.security.TokenCache; +import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.api.TezException; +import org.apache.tez.runtime.task.TaskReporter; +import org.apache.tez.runtime.task.TezTaskRunner; +import org.apache.tez.service.ContainerRunner; +import org.apache.tez.dag.api.TezConstants; +import org.apache.tez.runtime.api.ExecutionContext; +import org.apache.tez.runtime.api.impl.ExecutionContextImpl; +import org.apache.tez.runtime.common.objectregistry.ObjectRegistryImpl; +import org.apache.tez.runtime.task.TezChild; +import org.apache.tez.runtime.task.TezChild.ContainerExecutionResult; +import org.apache.tez.shufflehandler.ShuffleHandler; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos.RunContainerRequestProto; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos.SubmitWorkRequestProto; +import org.apache.tez.util.ProtoConverters; + +public class ContainerRunnerImpl extends AbstractService implements ContainerRunner { + + private static final Logger LOG = Logger.getLogger(ContainerRunnerImpl.class); + + private final ListeningExecutorService executorService; + private final AtomicReference<InetSocketAddress> localAddress; + private final String[] localDirsBase; + private final Map<String, String> localEnv = new HashMap<String, String>(); + private volatile FileSystem localFs; + private final long memoryPerExecutor; + // TODO Support for removing queued containers, interrupting / killing specific containers - when preemption is supported + + + + + public ContainerRunnerImpl(int numExecutors, String[] localDirsBase, + AtomicReference<InetSocketAddress> localAddress, + long totalMemoryAvailableBytes) { + super("ContainerRunnerImpl"); + Preconditions.checkState(numExecutors > 0, + "Invalid number of executors: " + numExecutors + ". Must be > 0"); + this.localDirsBase = localDirsBase; + this.localAddress = localAddress; + + ExecutorService raw = Executors.newFixedThreadPool(numExecutors, + new ThreadFactoryBuilder().setNameFormat("ContainerExecutor %d").build()); + this.executorService = MoreExecutors.listeningDecorator(raw); + + + // 80% of memory considered for accounted buffers. Rest for objects. + // TODO Tune this based on the available size. + this.memoryPerExecutor = (long)(totalMemoryAvailableBytes * 0.8 / (float) numExecutors); + + LOG.info("ContainerRunnerImpl config: " + + "memoryPerExecutorDerived=" + memoryPerExecutor + + ", numExecutors=" + numExecutors + ); + } + + @Override + public void serviceInit(Configuration conf) { + try { + localFs = FileSystem.getLocal(conf); + } catch (IOException e) { + throw new RuntimeException("Failed to setup local filesystem instance", e); + } + } + + @Override + public void serviceStart() { + } + + public void setShufflePort(int shufflePort) { + AuxiliaryServiceHelper.setServiceDataIntoEnv( + TezConstants.TEZ_SHUFFLE_HANDLER_SERVICE_ID, + ByteBuffer.allocate(4).putInt(shufflePort), localEnv); + } + + @Override + protected void serviceStop() throws Exception { + super.serviceStop(); + } + + // TODO Move this into a utilities class + private static String createAppSpecificLocalDir(String baseDir, String applicationIdString, + String user) { + return baseDir + File.separator + "usercache" + File.separator + user + File.separator + + "appcache" + File.separator + applicationIdString; + } + + /** + * Submit a container which is ready for running. + * The regular pull mechanism will be used to fetch work from the AM + * @param request + * @throws IOException + */ + @Override + public void queueContainer(RunContainerRequestProto request) throws IOException { + LOG.info("Queuing container for execution: " + request); + + Map<String, String> env = new HashMap<String, String>(); + env.putAll(localEnv); + env.put(ApplicationConstants.Environment.USER.name(), request.getUser()); + + String[] localDirs = new String[localDirsBase.length]; + + // Setup up local dirs to be application specific, and create them. + for (int i = 0; i < localDirsBase.length; i++) { + localDirs[i] = createAppSpecificLocalDir(localDirsBase[i], request.getApplicationIdString(), + request.getUser()); + localFs.mkdirs(new Path(localDirs[i])); + } + LOG.info("DEBUG: Dirs are: " + Arrays.toString(localDirs)); + + + // Setup workingDir. This is otherwise setup as Environment.PWD + // Used for re-localization, to add the user specified configuration (conf_pb_binary_stream) + String workingDir = localDirs[0]; + + Credentials credentials = new Credentials(); + DataInputBuffer dib = new DataInputBuffer(); + byte[] tokenBytes = request.getCredentialsBinary().toByteArray(); + dib.reset(tokenBytes, tokenBytes.length); + credentials.readTokenStorageStream(dib); + + Token<JobTokenIdentifier> jobToken = TokenCache.getSessionToken(credentials); + + // TODO Unregistering does not happen at the moment, since there's no signals on when an app completes. + LOG.info("DEBUG: Registering request with the ShuffleHandler"); + ShuffleHandler.get().registerApplication(request.getApplicationIdString(), jobToken, request.getUser()); + + + ContainerRunnerCallable callable = new ContainerRunnerCallable(request, new Configuration(getConfig()), + new ExecutionContextImpl(localAddress.get().getHostName()), env, localDirs, + workingDir, credentials, memoryPerExecutor); + ListenableFuture<ContainerExecutionResult> future = executorService + .submit(callable); + Futures.addCallback(future, new ContainerRunnerCallback(request, callable)); + } + + /** + * Submit an entire work unit - containerId + TaskSpec. + * This is intended for a task push from the AM + * + * @param request + * @throws IOException + */ + @Override + public void submitWork(SubmitWorkRequestProto request) throws + IOException { + LOG.info("Queuing work for execution: " + request); + + Map<String, String> env = new HashMap<String, String>(); + env.putAll(localEnv); + env.put(ApplicationConstants.Environment.USER.name(), request.getUser()); + + String[] localDirs = new String[localDirsBase.length]; + + // Setup up local dirs to be application specific, and create them. + for (int i = 0; i < localDirsBase.length; i++) { + localDirs[i] = createAppSpecificLocalDir(localDirsBase[i], request.getApplicationIdString(), + request.getUser()); + localFs.mkdirs(new Path(localDirs[i])); + } + if (LOG.isDebugEnabled()) { + LOG.debug("Dirs are: " + Arrays.toString(localDirs)); + } + + // Setup workingDir. This is otherwise setup as Environment.PWD + // Used for re-localization, to add the user specified configuration (conf_pb_binary_stream) + String workingDir = localDirs[0]; + + Credentials credentials = new Credentials(); + DataInputBuffer dib = new DataInputBuffer(); + byte[] tokenBytes = request.getCredentialsBinary().toByteArray(); + dib.reset(tokenBytes, tokenBytes.length); + credentials.readTokenStorageStream(dib); + + Token<JobTokenIdentifier> jobToken = TokenCache.getSessionToken(credentials); + + // TODO Unregistering does not happen at the moment, since there's no signals on when an app completes. + LOG.info("DEBUG: Registering request with the ShuffleHandler"); + ShuffleHandler.get().registerApplication(request.getApplicationIdString(), jobToken, request.getUser()); + TaskRunnerCallable callable = new TaskRunnerCallable(request, new Configuration(getConfig()), + new ExecutionContextImpl(localAddress.get().getHostName()), env, localDirs, + workingDir, credentials, memoryPerExecutor); + ListenableFuture<ContainerExecutionResult> future = executorService.submit(callable); + Futures.addCallback(future, new TaskRunnerCallback(request, callable)); + } + + + static class ContainerRunnerCallable implements Callable<ContainerExecutionResult> { + + private final RunContainerRequestProto request; + private final Configuration conf; + private final String workingDir; + private final String[] localDirs; + private final Map<String, String> envMap; + private final String pid = null; + private final ObjectRegistryImpl objectRegistry; + private final ExecutionContext executionContext; + private final Credentials credentials; + private final long memoryAvailable; + private volatile TezChild tezChild; + + + ContainerRunnerCallable(RunContainerRequestProto request, Configuration conf, + ExecutionContext executionContext, Map<String, String> envMap, + String[] localDirs, String workingDir, Credentials credentials, + long memoryAvailable) { + this.request = request; + this.conf = conf; + this.executionContext = executionContext; + this.envMap = envMap; + this.workingDir = workingDir; + this.localDirs = localDirs; + this.objectRegistry = new ObjectRegistryImpl(); + this.credentials = credentials; + this.memoryAvailable = memoryAvailable; + + } + + @Override + public ContainerExecutionResult call() throws Exception { + Stopwatch sw = new Stopwatch().start(); + tezChild = + new TezChild(conf, request.getAmHost(), request.getAmPort(), + request.getContainerIdString(), + request.getTokenIdentifier(), request.getAppAttemptNumber(), workingDir, localDirs, + envMap, objectRegistry, pid, + executionContext, credentials, memoryAvailable, request.getUser()); + ContainerExecutionResult result = tezChild.run(); + LOG.info("ExecutionTime for Container: " + request.getContainerIdString() + "=" + + sw.stop().elapsedMillis()); + return result; + } + + public TezChild getTezChild() { + return this.tezChild; + } + } + + + final class ContainerRunnerCallback implements FutureCallback<ContainerExecutionResult> { + + private final RunContainerRequestProto request; + private final ContainerRunnerCallable containerRunnerCallable; + + ContainerRunnerCallback(RunContainerRequestProto request, + ContainerRunnerCallable containerRunnerCallable) { + this.request = request; + this.containerRunnerCallable = containerRunnerCallable; + } + + // TODO Proper error handling + @Override + public void onSuccess(ContainerExecutionResult result) { + switch (result.getExitStatus()) { + case SUCCESS: + LOG.info("Successfully finished: " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString()); + break; + case EXECUTION_FAILURE: + LOG.info("Failed to run: " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString(), result.getThrowable()); + break; + case INTERRUPTED: + LOG.info( + "Interrupted while running: " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString(), result.getThrowable()); + break; + case ASKED_TO_DIE: + LOG.info( + "Asked to die while running: " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString()); + break; + } + } + + @Override + public void onFailure(Throwable t) { + LOG.error( + "TezChild execution failed for : " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString(), t); + TezChild tezChild = containerRunnerCallable.getTezChild(); + if (tezChild != null) { + tezChild.shutdown(); + } + } + } + + static class TaskRunnerCallable implements Callable<ContainerExecutionResult> { + + private final SubmitWorkRequestProto request; + private final Configuration conf; + private final String workingDir; + private final String[] localDirs; + private final Map<String, String> envMap; + private final String pid = null; + private final ObjectRegistryImpl objectRegistry; + private final ExecutionContext executionContext; + private final Credentials credentials; + private final long memoryAvailable; + private final ListeningExecutorService executor; + private volatile TezTaskRunner taskRunner; + private volatile TaskReporter taskReporter; + private TezTaskUmbilicalProtocol umbilical; + + + TaskRunnerCallable(SubmitWorkRequestProto request, Configuration conf, + ExecutionContext executionContext, Map<String, String> envMap, + String[] localDirs, String workingDir, Credentials credentials, + long memoryAvailable) { + this.request = request; + this.conf = conf; + this.executionContext = executionContext; + this.envMap = envMap; + this.workingDir = workingDir; + this.localDirs = localDirs; + this.objectRegistry = new ObjectRegistryImpl(); + this.credentials = credentials; + this.memoryAvailable = memoryAvailable; + // TODO This executor seems unnecessary. Here and TezChild + ExecutorService executorReal = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("TezTaskRunner_" + request.getTaskSpec().getTaskAttemptIdString()).build()); + executor = MoreExecutors.listeningDecorator(executorReal); + } + + @Override + public ContainerExecutionResult call() throws Exception { + + // TODO Consolidate this code with TezChild. + Stopwatch sw = new Stopwatch().start(); + UserGroupInformation taskUgi = UserGroupInformation.createRemoteUser(request.getUser()); + taskUgi.addCredentials(credentials); + + Token<JobTokenIdentifier> jobToken = TokenCache.getSessionToken(credentials); + Map<String, ByteBuffer> serviceConsumerMetadata = new HashMap<String, ByteBuffer>(); + serviceConsumerMetadata.put(TezConstants.TEZ_SHUFFLE_HANDLER_SERVICE_ID, + TezCommonUtils.convertJobTokenToBytes(jobToken)); + Multimap<String, String> startedInputsMap = HashMultimap.create(); + + UserGroupInformation taskOwner = + UserGroupInformation.createRemoteUser(request.getTokenIdentifier()); + final InetSocketAddress address = + NetUtils.createSocketAddrForHost(request.getAmHost(), request.getAmPort()); + SecurityUtil.setTokenService(jobToken, address); + taskOwner.addToken(jobToken); + umbilical = taskOwner.doAs(new PrivilegedExceptionAction<TezTaskUmbilicalProtocol>() { + @Override + public TezTaskUmbilicalProtocol run() throws Exception { + return RPC.getProxy(TezTaskUmbilicalProtocol.class, + TezTaskUmbilicalProtocol.versionID, address, conf); + } + }); + // TODO Stop reading this on each request. + taskReporter = new TaskReporter( + umbilical, + conf.getInt(TezConfiguration.TEZ_TASK_AM_HEARTBEAT_INTERVAL_MS, + TezConfiguration.TEZ_TASK_AM_HEARTBEAT_INTERVAL_MS_DEFAULT), + conf.getLong( + TezConfiguration.TEZ_TASK_AM_HEARTBEAT_COUNTER_INTERVAL_MS, + TezConfiguration.TEZ_TASK_AM_HEARTBEAT_COUNTER_INTERVAL_MS_DEFAULT), + conf.getInt(TezConfiguration.TEZ_TASK_MAX_EVENTS_PER_HEARTBEAT, + TezConfiguration.TEZ_TASK_MAX_EVENTS_PER_HEARTBEAT_DEFAULT), + new AtomicLong(0), + request.getContainerIdString()); + + taskRunner = new TezTaskRunner(conf, taskUgi, localDirs, + ProtoConverters.getTaskSpecfromProto(request.getTaskSpec()), umbilical, + request.getAppAttemptNumber(), + serviceConsumerMetadata, envMap, startedInputsMap, taskReporter, executor, objectRegistry, + pid, + executionContext, memoryAvailable); + + boolean shouldDie; + try { + shouldDie = !taskRunner.run(); + if (shouldDie) { + LOG.info("Got a shouldDie notification via hearbeats. Shutting down"); + return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.SUCCESS, null, + "Asked to die by the AM"); + } + } catch (IOException e) { + return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE, + e, "TaskExecutionFailure: " + e.getMessage()); + } catch (TezException e) { + return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE, + e, "TaskExecutionFailure: " + e.getMessage()); + } finally { + FileSystem.closeAllForUGI(taskUgi); + } + LOG.info("ExecutionTime for Container: " + request.getContainerIdString() + "=" + + sw.stop().elapsedMillis()); + return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.SUCCESS, null, + null); + } + + public void shutdown() { + executor.shutdownNow(); + if (taskReporter != null) { + taskReporter.shutdown(); + } + if (umbilical != null) { + RPC.stopProxy(umbilical); + } + } + } + + + final class TaskRunnerCallback implements FutureCallback<ContainerExecutionResult> { + + private final SubmitWorkRequestProto request; + private final TaskRunnerCallable taskRunnerCallable; + + TaskRunnerCallback(SubmitWorkRequestProto request, + TaskRunnerCallable containerRunnerCallable) { + this.request = request; + this.taskRunnerCallable = containerRunnerCallable; + } + + // TODO Proper error handling + @Override + public void onSuccess(ContainerExecutionResult result) { + switch (result.getExitStatus()) { + case SUCCESS: + LOG.info("Successfully finished: " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString()); + break; + case EXECUTION_FAILURE: + LOG.info("Failed to run: " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString(), result.getThrowable()); + break; + case INTERRUPTED: + LOG.info( + "Interrupted while running: " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString(), result.getThrowable()); + break; + case ASKED_TO_DIE: + LOG.info( + "Asked to die while running: " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString()); + break; + } + taskRunnerCallable.shutdown(); + } + + @Override + public void onFailure(Throwable t) { + LOG.error( + "TezTaskRunner execution failed for : " + request.getApplicationIdString() + ", containerId=" + + request.getContainerIdString(), t); + taskRunnerCallable.shutdown(); + } + } + +} \ No newline at end of file
http://git-wip-us.apache.org/repos/asf/tez/blob/36e7f854/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestService.java ---------------------------------------------------------------------- diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestService.java b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestService.java new file mode 100644 index 0000000..012e352 --- /dev/null +++ b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestService.java @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.service.impl; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.service.AbstractService; +import org.apache.hadoop.util.StringUtils; +import org.apache.log4j.Logger; +import org.apache.tez.service.ContainerRunner; +import org.apache.tez.shufflehandler.ShuffleHandler; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos.RunContainerRequestProto; + +public class TezTestService extends AbstractService implements ContainerRunner { + + private static final Logger LOG = Logger.getLogger(TezTestService.class); + + private final Configuration shuffleHandlerConf; + private final int numExecutors; + + private final TezTestServiceProtocolServerImpl server; + private final ContainerRunnerImpl containerRunner; + private final String[] localDirs; + + private final AtomicInteger numSubmissions = new AtomicInteger(0); + + + private final AtomicReference<InetSocketAddress> address = new AtomicReference<InetSocketAddress>(); + + public TezTestService(Configuration conf, int numExecutors, long memoryAvailable, String[] localDirs) { + super(TezTestService.class.getSimpleName()); + this.numExecutors = numExecutors; + this.localDirs = localDirs; + + long memoryAvailableBytes = memoryAvailable; + long jvmMax = Runtime.getRuntime().maxMemory(); + + LOG.info(TezTestService.class.getSimpleName() + " created with the following configuration: " + + "numExecutors=" + numExecutors + + ", workDirs=" + Arrays.toString(localDirs) + + ", memoryAvailable=" + memoryAvailable + + ", jvmMaxMemory=" + jvmMax); + + Preconditions.checkArgument(this.numExecutors > 0); + Preconditions.checkArgument(this.localDirs != null && this.localDirs.length > 0, + "Work dirs must be specified"); + Preconditions.checkState(jvmMax >= memoryAvailableBytes, + "Invalid configuration. Xmx value too small. maxAvailable=" + jvmMax + ", configured=" + + memoryAvailableBytes); + + this.shuffleHandlerConf = new Configuration(conf); + // Start Shuffle on a random port + this.shuffleHandlerConf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); + this.shuffleHandlerConf.set(ShuffleHandler.SHUFFLE_HANDLER_LOCAL_DIRS, StringUtils.arrayToString(localDirs)); + + this.server = new TezTestServiceProtocolServerImpl(this, address); + this.containerRunner = new ContainerRunnerImpl(numExecutors, localDirs, address, + memoryAvailableBytes); + } + + @Override + public void serviceInit(Configuration conf) { + server.init(conf); + containerRunner.init(conf); + } + + @Override + public void serviceStart() throws Exception { + ShuffleHandler.initializeAndStart(shuffleHandlerConf); + containerRunner.setShufflePort(ShuffleHandler.get().getPort()); + server.start(); + containerRunner.start(); + } + + public void serviceStop() throws Exception { + containerRunner.stop(); + server.stop(); + ShuffleHandler.get().stop(); + } + + public InetSocketAddress getListenerAddress() { + return server.getBindAddress(); + } + + public int getShufflePort() { + return ShuffleHandler.get().getPort(); + } + + + + @Override + public void queueContainer(RunContainerRequestProto request) throws IOException { + numSubmissions.incrementAndGet(); + containerRunner.queueContainer(request); + } + + @Override + public void submitWork(TezTestServiceProtocolProtos.SubmitWorkRequestProto request) throws + IOException { + numSubmissions.incrementAndGet(); + containerRunner.submitWork(request); + } + + public int getNumSubmissions() { + return numSubmissions.get(); + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/36e7f854/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestServiceProtocolClientImpl.java ---------------------------------------------------------------------- diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestServiceProtocolClientImpl.java b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestServiceProtocolClientImpl.java new file mode 100644 index 0000000..10d2952 --- /dev/null +++ b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestServiceProtocolClientImpl.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.service.impl; + +import java.io.IOException; +import java.net.InetSocketAddress; + +import com.google.protobuf.RpcController; +import com.google.protobuf.ServiceException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.ipc.ProtobufRpcEngine; +import org.apache.hadoop.ipc.RPC; +import org.apache.hadoop.net.NetUtils; +import org.apache.tez.service.TezTestServiceProtocolBlockingPB; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos.RunContainerRequestProto; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos.RunContainerResponseProto; + + +public class TezTestServiceProtocolClientImpl implements TezTestServiceProtocolBlockingPB { + + private final Configuration conf; + private final InetSocketAddress serverAddr; + TezTestServiceProtocolBlockingPB proxy; + + + public TezTestServiceProtocolClientImpl(Configuration conf, String hostname, int port) { + this.conf = conf; + this.serverAddr = NetUtils.createSocketAddr(hostname, port); + } + + @Override + public RunContainerResponseProto runContainer(RpcController controller, + RunContainerRequestProto request) throws + ServiceException { + try { + return getProxy().runContainer(null, request); + } catch (IOException e) { + throw new ServiceException(e); + } + } + + @Override + public TezTestServiceProtocolProtos.SubmitWorkResponseProto submitWork(RpcController controller, + TezTestServiceProtocolProtos.SubmitWorkRequestProto request) throws + ServiceException { + try { + return getProxy().submitWork(null, request); + } catch (IOException e) { + throw new ServiceException(e); + } + } + + + public TezTestServiceProtocolBlockingPB getProxy() throws IOException { + if (proxy == null) { + proxy = createProxy(); + } + return proxy; + } + + public TezTestServiceProtocolBlockingPB createProxy() throws IOException { + TezTestServiceProtocolBlockingPB p; + // TODO Fix security + RPC.setProtocolEngine(conf, TezTestServiceProtocolBlockingPB.class, ProtobufRpcEngine.class); + p = (TezTestServiceProtocolBlockingPB) RPC + .getProxy(TezTestServiceProtocolBlockingPB.class, 0, serverAddr, conf); + return p; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/tez/blob/36e7f854/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestServiceProtocolServerImpl.java ---------------------------------------------------------------------- diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestServiceProtocolServerImpl.java b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestServiceProtocolServerImpl.java new file mode 100644 index 0000000..d7f8444 --- /dev/null +++ b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/TezTestServiceProtocolServerImpl.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.service.impl; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.concurrent.atomic.AtomicReference; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.BlockingService; +import com.google.protobuf.RpcController; +import com.google.protobuf.ServiceException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.ipc.ProtobufRpcEngine; +import org.apache.hadoop.ipc.RPC; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.service.AbstractService; +import org.apache.tez.service.ContainerRunner; +import org.apache.tez.service.TezTestServiceProtocolBlockingPB; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos.RunContainerRequestProto; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos.RunContainerResponseProto; +import org.apache.tez.test.service.rpc.TezTestServiceProtocolProtos.SubmitWorkResponseProto; + +public class TezTestServiceProtocolServerImpl extends AbstractService + implements TezTestServiceProtocolBlockingPB { + + private static final Log LOG = LogFactory.getLog(TezTestServiceProtocolServerImpl.class); + + private final ContainerRunner containerRunner; + private RPC.Server server; + private final AtomicReference<InetSocketAddress> bindAddress; + + + public TezTestServiceProtocolServerImpl(ContainerRunner containerRunner, + AtomicReference<InetSocketAddress> address) { + super(TezTestServiceProtocolServerImpl.class.getSimpleName()); + this.containerRunner = containerRunner; + this.bindAddress = address; + } + + @Override + public RunContainerResponseProto runContainer(RpcController controller, + RunContainerRequestProto request) throws + ServiceException { + LOG.info("Received request: " + request); + try { + containerRunner.queueContainer(request); + } catch (IOException e) { + throw new ServiceException(e); + } + return RunContainerResponseProto.getDefaultInstance(); + } + + @Override + public SubmitWorkResponseProto submitWork(RpcController controller, TezTestServiceProtocolProtos.SubmitWorkRequestProto request) throws + ServiceException { + LOG.info("Received submitWork request: " + request); + try { + containerRunner.submitWork(request); + } catch (IOException e) { + e.printStackTrace(); + } + return SubmitWorkResponseProto.getDefaultInstance(); + } + + + @Override + public void serviceStart() { + Configuration conf = getConfig(); + + int numHandlers = 3; + InetSocketAddress addr = new InetSocketAddress(0); + + try { + server = createServer(TezTestServiceProtocolBlockingPB.class, addr, conf, numHandlers, + TezTestServiceProtocolProtos.TezTestServiceProtocol.newReflectiveBlockingService(this)); + server.start(); + } catch (IOException e) { + LOG.error("Failed to run RPC Server", e); + throw new RuntimeException(e); + } + + InetSocketAddress serverBindAddress = NetUtils.getConnectAddress(server); + this.bindAddress.set(NetUtils.createSocketAddrForHost( + serverBindAddress.getAddress().getCanonicalHostName(), + serverBindAddress.getPort())); + LOG.info("Instantiated TestTestServiceListener at " + bindAddress); + } + + @Override + public void serviceStop() { + if (server != null) { + server.stop(); + } + } + + @InterfaceAudience.Private + @VisibleForTesting + InetSocketAddress getBindAddress() { + return this.bindAddress.get(); + } + + private RPC.Server createServer(Class<?> pbProtocol, InetSocketAddress addr, Configuration conf, + int numHandlers, BlockingService blockingService) throws + IOException { + RPC.setProtocolEngine(conf, pbProtocol, ProtobufRpcEngine.class); + RPC.Server server = new RPC.Builder(conf) + .setProtocol(pbProtocol) + .setInstance(blockingService) + .setBindAddress(addr.getHostName()) + .setPort(0) + .setNumHandlers(numHandlers) + .build(); + // TODO Add security. + return server; + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/36e7f854/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/FadvisedChunkedFile.java ---------------------------------------------------------------------- diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/FadvisedChunkedFile.java b/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/FadvisedChunkedFile.java new file mode 100644 index 0000000..65588fe --- /dev/null +++ b/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/FadvisedChunkedFile.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.shufflehandler; + +import java.io.FileDescriptor; +import java.io.IOException; +import java.io.RandomAccessFile; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.io.ReadaheadPool; +import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest; +import org.apache.hadoop.io.nativeio.NativeIO; +import org.jboss.netty.handler.stream.ChunkedFile; + +public class FadvisedChunkedFile extends ChunkedFile { + + private static final Log LOG = LogFactory.getLog(FadvisedChunkedFile.class); + + private final boolean manageOsCache; + private final int readaheadLength; + private final ReadaheadPool readaheadPool; + private final FileDescriptor fd; + private final String identifier; + + private ReadaheadRequest readaheadRequest; + + public FadvisedChunkedFile(RandomAccessFile file, long position, long count, + int chunkSize, boolean manageOsCache, int readaheadLength, + ReadaheadPool readaheadPool, String identifier) throws IOException { + super(file, position, count, chunkSize); + this.manageOsCache = manageOsCache; + this.readaheadLength = readaheadLength; + this.readaheadPool = readaheadPool; + this.fd = file.getFD(); + this.identifier = identifier; + } + + @Override + public Object nextChunk() throws Exception { + if (manageOsCache && readaheadPool != null) { + readaheadRequest = readaheadPool + .readaheadStream(identifier, fd, getCurrentOffset(), readaheadLength, + getEndOffset(), readaheadRequest); + } + return super.nextChunk(); + } + + @Override + public void close() throws Exception { + if (readaheadRequest != null) { + readaheadRequest.cancel(); + } + if (manageOsCache && getEndOffset() - getStartOffset() > 0) { + try { + NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(identifier, + fd, + getStartOffset(), getEndOffset() - getStartOffset(), + NativeIO.POSIX.POSIX_FADV_DONTNEED); + } catch (Throwable t) { + LOG.warn("Failed to manage OS cache for " + identifier, t); + } + } + super.close(); + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/36e7f854/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/FadvisedFileRegion.java ---------------------------------------------------------------------- diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/FadvisedFileRegion.java b/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/FadvisedFileRegion.java new file mode 100644 index 0000000..bdffe52 --- /dev/null +++ b/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/FadvisedFileRegion.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.shufflehandler; + +import java.io.FileDescriptor; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.io.ReadaheadPool; +import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest; +import org.apache.hadoop.io.nativeio.NativeIO; +import org.jboss.netty.channel.DefaultFileRegion; + +public class FadvisedFileRegion extends DefaultFileRegion { + + private static final Log LOG = LogFactory.getLog(FadvisedFileRegion.class); + + private final boolean manageOsCache; + private final int readaheadLength; + private final ReadaheadPool readaheadPool; + private final FileDescriptor fd; + private final String identifier; + private final long count; + private final long position; + private final int shuffleBufferSize; + private final boolean shuffleTransferToAllowed; + private final FileChannel fileChannel; + + private ReadaheadRequest readaheadRequest; + + public FadvisedFileRegion(RandomAccessFile file, long position, long count, + boolean manageOsCache, int readaheadLength, ReadaheadPool readaheadPool, + String identifier, int shuffleBufferSize, + boolean shuffleTransferToAllowed) throws IOException { + super(file.getChannel(), position, count); + this.manageOsCache = manageOsCache; + this.readaheadLength = readaheadLength; + this.readaheadPool = readaheadPool; + this.fd = file.getFD(); + this.identifier = identifier; + this.fileChannel = file.getChannel(); + this.count = count; + this.position = position; + this.shuffleBufferSize = shuffleBufferSize; + this.shuffleTransferToAllowed = shuffleTransferToAllowed; + } + + @Override + public long transferTo(WritableByteChannel target, long position) + throws IOException { + if (manageOsCache && readaheadPool != null) { + readaheadRequest = readaheadPool.readaheadStream(identifier, fd, + getPosition() + position, readaheadLength, + getPosition() + getCount(), readaheadRequest); + } + + if(this.shuffleTransferToAllowed) { + return super.transferTo(target, position); + } else { + return customShuffleTransfer(target, position); + } + } + + /** + * This method transfers data using local buffer. It transfers data from + * a disk to a local buffer in memory, and then it transfers data from the + * buffer to the target. This is used only if transferTo is disallowed in + * the configuration file. super.TransferTo does not perform well on Windows + * due to a small IO request generated. customShuffleTransfer can control + * the size of the IO requests by changing the size of the intermediate + * buffer. + */ + @VisibleForTesting + long customShuffleTransfer(WritableByteChannel target, long position) + throws IOException { + long actualCount = this.count - position; + if (actualCount < 0 || position < 0) { + throw new IllegalArgumentException( + "position out of range: " + position + + " (expected: 0 - " + (this.count - 1) + ')'); + } + if (actualCount == 0) { + return 0L; + } + + long trans = actualCount; + int readSize; + ByteBuffer byteBuffer = ByteBuffer.allocate(this.shuffleBufferSize); + + while(trans > 0L && + (readSize = fileChannel.read(byteBuffer, this.position+position)) > 0) { + //adjust counters and buffer limit + if(readSize < trans) { + trans -= readSize; + position += readSize; + byteBuffer.flip(); + } else { + //We can read more than we need if the actualCount is not multiple + //of the byteBuffer size and file is big enough. In that case we cannot + //use flip method but we need to set buffer limit manually to trans. + byteBuffer.limit((int)trans); + byteBuffer.position(0); + position += trans; + trans = 0; + } + + //write data to the target + while(byteBuffer.hasRemaining()) { + target.write(byteBuffer); + } + + byteBuffer.clear(); + } + + return actualCount - trans; + } + + + @Override + public void releaseExternalResources() { + if (readaheadRequest != null) { + readaheadRequest.cancel(); + } + super.releaseExternalResources(); + } + + /** + * Call when the transfer completes successfully so we can advise the OS that + * we don't need the region to be cached anymore. + */ + public void transferSuccessful() { + if (manageOsCache && getCount() > 0) { + try { + NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(identifier, + fd, getPosition(), getCount(), + NativeIO.POSIX.POSIX_FADV_DONTNEED); + } catch (Throwable t) { + LOG.warn("Failed to manage OS cache for " + identifier, t); + } + } + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/36e7f854/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/IndexCache.java ---------------------------------------------------------------------- diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/IndexCache.java b/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/IndexCache.java new file mode 100644 index 0000000..9a51ca0 --- /dev/null +++ b/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/IndexCache.java @@ -0,0 +1,199 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tez.shufflehandler; + +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.tez.runtime.library.common.Constants; +import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord; +import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord; + +class IndexCache { + + private final Configuration conf; + private final int totalMemoryAllowed; + private AtomicInteger totalMemoryUsed = new AtomicInteger(); + private static final Log LOG = LogFactory.getLog(IndexCache.class); + + private final ConcurrentHashMap<String,IndexInformation> cache = + new ConcurrentHashMap<String,IndexInformation>(); + + private final LinkedBlockingQueue<String> queue = + new LinkedBlockingQueue<String>(); + + public IndexCache(Configuration conf) { + this.conf = conf; + totalMemoryAllowed = 10 * 1024 * 1024; + LOG.info("IndexCache created with max memory = " + totalMemoryAllowed); + } + + /** + * This method gets the index information for the given mapId and reduce. + * It reads the index file into cache if it is not already present. + * @param mapId + * @param reduce + * @param fileName The file to read the index information from if it is not + * already present in the cache + * @param expectedIndexOwner The expected owner of the index file + * @return The Index Information + * @throws IOException + */ + public TezIndexRecord getIndexInformation(String mapId, int reduce, + Path fileName, String expectedIndexOwner) + throws IOException { + + IndexInformation info = cache.get(mapId); + + if (info == null) { + info = readIndexFileToCache(fileName, mapId, expectedIndexOwner); + } else { + synchronized(info) { + while (isUnderConstruction(info)) { + try { + info.wait(); + } catch (InterruptedException e) { + throw new IOException("Interrupted waiting for construction", e); + } + } + } + LOG.debug("IndexCache HIT: MapId " + mapId + " found"); + } + + if (info.mapSpillRecord.size() == 0 || + info.mapSpillRecord.size() <= reduce) { + throw new IOException("Invalid request " + + " Map Id = " + mapId + " Reducer = " + reduce + + " Index Info Length = " + info.mapSpillRecord.size()); + } + return info.mapSpillRecord.getIndex(reduce); + } + + private boolean isUnderConstruction(IndexInformation info) { + synchronized(info) { + return (null == info.mapSpillRecord); + } + } + + private IndexInformation readIndexFileToCache(Path indexFileName, + String mapId, + String expectedIndexOwner) + throws IOException { + IndexInformation info; + IndexInformation newInd = new IndexInformation(); + if ((info = cache.putIfAbsent(mapId, newInd)) != null) { + synchronized(info) { + while (isUnderConstruction(info)) { + try { + info.wait(); + } catch (InterruptedException e) { + throw new IOException("Interrupted waiting for construction", e); + } + } + } + LOG.debug("IndexCache HIT: MapId " + mapId + " found"); + return info; + } + LOG.debug("IndexCache MISS: MapId " + mapId + " not found") ; + TezSpillRecord tmp = null; + try { + tmp = new TezSpillRecord(indexFileName, conf, expectedIndexOwner); + } catch (Throwable e) { + tmp = new TezSpillRecord(0); + cache.remove(mapId); + throw new IOException("Error Reading IndexFile", e); + } finally { + synchronized (newInd) { + newInd.mapSpillRecord = tmp; + newInd.notifyAll(); + } + } + queue.add(mapId); + + if (totalMemoryUsed.addAndGet(newInd.getSize()) > totalMemoryAllowed) { + freeIndexInformation(); + } + return newInd; + } + + /** + * This method removes the map from the cache if index information for this + * map is loaded(size>0), index information entry in cache will not be + * removed if it is in the loading phrase(size=0), this prevents corruption + * of totalMemoryUsed. It should be called when a map output on this tracker + * is discarded. + * @param mapId The taskID of this map. + */ + public void removeMap(String mapId) { + IndexInformation info = cache.get(mapId); + if (info == null || ((info != null) && isUnderConstruction(info))) { + return; + } + info = cache.remove(mapId); + if (info != null) { + totalMemoryUsed.addAndGet(-info.getSize()); + if (!queue.remove(mapId)) { + LOG.warn("Map ID" + mapId + " not found in queue!!"); + } + } else { + LOG.info("Map ID " + mapId + " not found in cache"); + } + } + + /** + * This method checks if cache and totolMemoryUsed is consistent. + * It is only used for unit test. + * @return True if cache and totolMemoryUsed is consistent + */ + boolean checkTotalMemoryUsed() { + int totalSize = 0; + for (IndexInformation info : cache.values()) { + totalSize += info.getSize(); + } + return totalSize == totalMemoryUsed.get(); + } + + /** + * Bring memory usage below totalMemoryAllowed. + */ + private synchronized void freeIndexInformation() { + while (totalMemoryUsed.get() > totalMemoryAllowed) { + String s = queue.remove(); + IndexInformation info = cache.remove(s); + if (info != null) { + totalMemoryUsed.addAndGet(-info.getSize()); + } + } + } + + private static class IndexInformation { + TezSpillRecord mapSpillRecord; + + int getSize() { + return mapSpillRecord == null + ? 0 + : mapSpillRecord.size() * Constants.MAP_OUTPUT_INDEX_RECORD_LENGTH; + } + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/36e7f854/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/ShuffleHandler.java ---------------------------------------------------------------------- diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/ShuffleHandler.java b/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/ShuffleHandler.java new file mode 100644 index 0000000..cc82d74 --- /dev/null +++ b/tez-ext-service-tests/src/test/java/org/apache/tez/shufflehandler/ShuffleHandler.java @@ -0,0 +1,840 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.shufflehandler; + +import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer; +import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE; +import static org.jboss.netty.handler.codec.http.HttpMethod.GET; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED; +import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; + +import javax.crypto.SecretKey; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.net.InetSocketAddress; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.regex.Pattern; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.LocalDirAllocator; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DataInputByteBuffer; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.ReadaheadPool; +import org.apache.hadoop.io.SecureIOUtils; +import org.apache.hadoop.metrics2.annotation.Metric; +import org.apache.hadoop.metrics2.annotation.Metrics; +import org.apache.hadoop.metrics2.lib.MutableCounterInt; +import org.apache.hadoop.metrics2.lib.MutableCounterLong; +import org.apache.hadoop.metrics2.lib.MutableGaugeInt; +import org.apache.hadoop.security.ssl.SSLFactory; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.util.Shell; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.tez.common.security.JobTokenIdentifier; +import org.apache.tez.common.security.JobTokenSecretManager; +import org.apache.tez.runtime.library.common.security.SecureShuffleUtils; +import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ShuffleHeader; +import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord; +import org.jboss.netty.bootstrap.ServerBootstrap; +import org.jboss.netty.buffer.ChannelBuffers; +import org.jboss.netty.channel.Channel; +import org.jboss.netty.channel.ChannelFactory; +import org.jboss.netty.channel.ChannelFuture; +import org.jboss.netty.channel.ChannelFutureListener; +import org.jboss.netty.channel.ChannelHandlerContext; +import org.jboss.netty.channel.ChannelPipeline; +import org.jboss.netty.channel.ChannelPipelineFactory; +import org.jboss.netty.channel.ChannelStateEvent; +import org.jboss.netty.channel.Channels; +import org.jboss.netty.channel.ExceptionEvent; +import org.jboss.netty.channel.MessageEvent; +import org.jboss.netty.channel.SimpleChannelUpstreamHandler; +import org.jboss.netty.channel.group.ChannelGroup; +import org.jboss.netty.channel.group.DefaultChannelGroup; +import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; +import org.jboss.netty.handler.codec.frame.TooLongFrameException; +import org.jboss.netty.handler.codec.http.DefaultHttpResponse; +import org.jboss.netty.handler.codec.http.HttpChunkAggregator; +import org.jboss.netty.handler.codec.http.HttpHeaders; +import org.jboss.netty.handler.codec.http.HttpRequest; +import org.jboss.netty.handler.codec.http.HttpRequestDecoder; +import org.jboss.netty.handler.codec.http.HttpResponse; +import org.jboss.netty.handler.codec.http.HttpResponseEncoder; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; +import org.jboss.netty.handler.codec.http.QueryStringDecoder; +import org.jboss.netty.handler.ssl.SslHandler; +import org.jboss.netty.handler.stream.ChunkedWriteHandler; +import org.jboss.netty.util.CharsetUtil; + +public class ShuffleHandler { + + private static final Log LOG = LogFactory.getLog(ShuffleHandler.class); + + public static final String SHUFFLE_HANDLER_LOCAL_DIRS = "tez.shuffle.handler.local-dirs"; + + public static final String SHUFFLE_MANAGE_OS_CACHE = "mapreduce.shuffle.manage.os.cache"; + public static final boolean DEFAULT_SHUFFLE_MANAGE_OS_CACHE = true; + + public static final String SHUFFLE_READAHEAD_BYTES = "mapreduce.shuffle.readahead.bytes"; + public static final int DEFAULT_SHUFFLE_READAHEAD_BYTES = 4 * 1024 * 1024; + + // pattern to identify errors related to the client closing the socket early + // idea borrowed from Netty SslHandler + private static final Pattern IGNORABLE_ERROR_MESSAGE = Pattern.compile( + "^.*(?:connection.*reset|connection.*closed|broken.*pipe).*$", + Pattern.CASE_INSENSITIVE); + + private int port; + private final ChannelFactory selector; + private final ChannelGroup accepted = new DefaultChannelGroup(); + protected HttpPipelineFactory pipelineFact; + private final int sslFileBufferSize; + private final Configuration conf; + + private final ConcurrentMap<String, Boolean> registeredApps = new ConcurrentHashMap<String, Boolean>(); + + /** + * Should the shuffle use posix_fadvise calls to manage the OS cache during + * sendfile + */ + private final boolean manageOsCache; + private final int readaheadLength; + private final int maxShuffleConnections; + private final int shuffleBufferSize; + private final boolean shuffleTransferToAllowed; + private final ReadaheadPool readaheadPool = ReadaheadPool.getInstance(); + + private Map<String,String> userRsrc; + private JobTokenSecretManager secretManager; + + // TODO Fix this for tez. + public static final String MAPREDUCE_SHUFFLE_SERVICEID = + "mapreduce_shuffle"; + + public static final String SHUFFLE_PORT_CONFIG_KEY = "tez.shuffle.port"; + public static final int DEFAULT_SHUFFLE_PORT = 15551; + + // TODO Change configs to remove mapreduce references. + public static final String SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED = + "mapreduce.shuffle.connection-keep-alive.enable"; + public static final boolean DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED = false; + + public static final String SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT = + "mapreduce.shuffle.connection-keep-alive.timeout"; + public static final int DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT = 5; //seconds + + public static final String SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE = + "mapreduce.shuffle.mapoutput-info.meta.cache.size"; + public static final int DEFAULT_SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE = + 1000; + + public static final String CONNECTION_CLOSE = "close"; + + public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY = + "mapreduce.shuffle.ssl.file.buffer.size"; + + public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024; + + public static final String MAX_SHUFFLE_CONNECTIONS = "mapreduce.shuffle.max.connections"; + public static final int DEFAULT_MAX_SHUFFLE_CONNECTIONS = 0; // 0 implies no limit + + public static final String MAX_SHUFFLE_THREADS = "mapreduce.shuffle.max.threads"; + // 0 implies Netty default of 2 * number of available processors + public static final int DEFAULT_MAX_SHUFFLE_THREADS = 0; + + public static final String SHUFFLE_BUFFER_SIZE = + "mapreduce.shuffle.transfer.buffer.size"; + public static final int DEFAULT_SHUFFLE_BUFFER_SIZE = 128 * 1024; + + public static final String SHUFFLE_TRANSFERTO_ALLOWED = + "mapreduce.shuffle.transferTo.allowed"; + public static final boolean DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = true; + public static final boolean WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = + false; + + final boolean connectionKeepAliveEnabled; + final int connectionKeepAliveTimeOut; + final int mapOutputMetaInfoCacheSize; + private static final AtomicBoolean started = new AtomicBoolean(false); + private static final AtomicBoolean initing = new AtomicBoolean(false); + private static ShuffleHandler INSTANCE; + + @Metrics(about="Shuffle output metrics", context="mapred") + static class ShuffleMetrics implements ChannelFutureListener { + @Metric("Shuffle output in bytes") + MutableCounterLong shuffleOutputBytes; + @Metric("# of failed shuffle outputs") + MutableCounterInt shuffleOutputsFailed; + @Metric("# of succeeeded shuffle outputs") + MutableCounterInt shuffleOutputsOK; + @Metric("# of current shuffle connections") + MutableGaugeInt shuffleConnections; + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + shuffleOutputsOK.incr(); + } else { + shuffleOutputsFailed.incr(); + } + shuffleConnections.decr(); + } + } + + public ShuffleHandler(Configuration conf) { + this.conf = conf; + manageOsCache = conf.getBoolean(SHUFFLE_MANAGE_OS_CACHE, + DEFAULT_SHUFFLE_MANAGE_OS_CACHE); + + readaheadLength = conf.getInt(SHUFFLE_READAHEAD_BYTES, + DEFAULT_SHUFFLE_READAHEAD_BYTES); + + maxShuffleConnections = conf.getInt(MAX_SHUFFLE_CONNECTIONS, + DEFAULT_MAX_SHUFFLE_CONNECTIONS); + int maxShuffleThreads = conf.getInt(MAX_SHUFFLE_THREADS, + DEFAULT_MAX_SHUFFLE_THREADS); + if (maxShuffleThreads == 0) { + maxShuffleThreads = 2 * Runtime.getRuntime().availableProcessors(); + } + + shuffleBufferSize = conf.getInt(SHUFFLE_BUFFER_SIZE, + DEFAULT_SHUFFLE_BUFFER_SIZE); + + shuffleTransferToAllowed = conf.getBoolean(SHUFFLE_TRANSFERTO_ALLOWED, + (Shell.WINDOWS)?WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED: + DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED); + + ThreadFactory bossFactory = new ThreadFactoryBuilder() + .setNameFormat("ShuffleHandler Netty Boss #%d") + .build(); + ThreadFactory workerFactory = new ThreadFactoryBuilder() + .setNameFormat("ShuffleHandler Netty Worker #%d") + .build(); + + selector = new NioServerSocketChannelFactory( + Executors.newCachedThreadPool(bossFactory), + Executors.newCachedThreadPool(workerFactory), + maxShuffleThreads); + + sslFileBufferSize = conf.getInt(SUFFLE_SSL_FILE_BUFFER_SIZE_KEY, + DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE); + connectionKeepAliveEnabled = + conf.getBoolean(SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, + DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED); + connectionKeepAliveTimeOut = + Math.max(1, conf.getInt(SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, + DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT)); + mapOutputMetaInfoCacheSize = + Math.max(1, conf.getInt(SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE, + DEFAULT_SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE)); + + userRsrc = new ConcurrentHashMap<String,String>(); + secretManager = new JobTokenSecretManager(); + } + + + public void start() throws Exception { + ServerBootstrap bootstrap = new ServerBootstrap(selector); + try { + pipelineFact = new HttpPipelineFactory(conf); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + bootstrap.setPipelineFactory(pipelineFact); + port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT); + Channel ch = bootstrap.bind(new InetSocketAddress(port)); + accepted.add(ch); + port = ((InetSocketAddress)ch.getLocalAddress()).getPort(); + conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port)); + pipelineFact.SHUFFLE.setPort(port); + LOG.info("TezShuffleHandler" + " listening on port " + port); + } + + public static void initializeAndStart(Configuration conf) throws Exception { + if (!initing.getAndSet(true)) { + INSTANCE = new ShuffleHandler(conf); + INSTANCE.start(); + started.set(true); + } + } + + public static ShuffleHandler get() { + Preconditions.checkState(started.get(), "ShuffleHandler must be started before invoking started"); + return INSTANCE; + } + + /** + * Serialize the shuffle port into a ByteBuffer for use later on. + * @param port the port to be sent to the ApplciationMaster + * @return the serialized form of the port. + */ + public static ByteBuffer serializeMetaData(int port) throws IOException { + //TODO these bytes should be versioned + DataOutputBuffer port_dob = new DataOutputBuffer(); + port_dob.writeInt(port); + return ByteBuffer.wrap(port_dob.getData(), 0, port_dob.getLength()); + } + + /** + * A helper function to deserialize the metadata returned by ShuffleHandler. + * @param meta the metadata returned by the ShuffleHandler + * @return the port the Shuffle Handler is listening on to serve shuffle data. + */ + public static int deserializeMetaData(ByteBuffer meta) throws IOException { + //TODO this should be returning a class not just an int + DataInputByteBuffer in = new DataInputByteBuffer(); + in.reset(meta); + int port = in.readInt(); + return port; + } + + /** + * A helper function to serialize the JobTokenIdentifier to be sent to the + * ShuffleHandler as ServiceData. + * @param jobToken the job token to be used for authentication of + * shuffle data requests. + * @return the serialized version of the jobToken. + */ + public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken) throws IOException { + //TODO these bytes should be versioned + DataOutputBuffer jobToken_dob = new DataOutputBuffer(); + jobToken.write(jobToken_dob); + return ByteBuffer.wrap(jobToken_dob.getData(), 0, jobToken_dob.getLength()); + } + + static Token<JobTokenIdentifier> deserializeServiceData(ByteBuffer secret) throws IOException { + DataInputByteBuffer in = new DataInputByteBuffer(); + in.reset(secret); + Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>(); + jt.readFields(in); + return jt; + } + + public int getPort() { + return port; + } + + public void registerApplication(String applicationIdString, Token<JobTokenIdentifier> appToken, + String user) { + Boolean registered = registeredApps.putIfAbsent(applicationIdString, Boolean.valueOf(true)); + if (registered == null) { + recordJobShuffleInfo(applicationIdString, user, appToken); + } + } + + public void unregisterApplication(String applicationIdString) { + removeJobShuffleInfo(applicationIdString); + } + + + public void stop() throws Exception { + accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS); + if (selector != null) { + ServerBootstrap bootstrap = new ServerBootstrap(selector); + bootstrap.releaseExternalResources(); + } + if (pipelineFact != null) { + pipelineFact.destroy(); + } + } + + protected Shuffle getShuffle(Configuration conf) { + return new Shuffle(conf); + } + + + private void addJobToken(String appIdString, String user, + Token<JobTokenIdentifier> jobToken) { + String jobIdString = appIdString.replace("application", "job"); + userRsrc.put(jobIdString, user); + secretManager.addTokenForJob(jobIdString, jobToken); + LOG.info("Added token for " + jobIdString); + } + + private void recordJobShuffleInfo(String appIdString, String user, + Token<JobTokenIdentifier> jobToken) { + addJobToken(appIdString, user, jobToken); + } + + private void removeJobShuffleInfo(String appIdString) { + secretManager.removeTokenForJob(appIdString); + userRsrc.remove(appIdString); + } + + class HttpPipelineFactory implements ChannelPipelineFactory { + + final Shuffle SHUFFLE; + private SSLFactory sslFactory; + + public HttpPipelineFactory(Configuration conf) throws Exception { + SHUFFLE = getShuffle(conf); + // TODO Setup SSL Shuffle +// if (conf.getBoolean(MRConfig.SHUFFLE_SSL_ENABLED_KEY, +// MRConfig.SHUFFLE_SSL_ENABLED_DEFAULT)) { +// LOG.info("Encrypted shuffle is enabled."); +// sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf); +// sslFactory.init(); +// } + } + + public void destroy() { + if (sslFactory != null) { + sslFactory.destroy(); + } + } + + @Override + public ChannelPipeline getPipeline() throws Exception { + ChannelPipeline pipeline = Channels.pipeline(); + if (sslFactory != null) { + pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine())); + } + pipeline.addLast("decoder", new HttpRequestDecoder()); + pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16)); + pipeline.addLast("encoder", new HttpResponseEncoder()); + pipeline.addLast("chunking", new ChunkedWriteHandler()); + pipeline.addLast("shuffle", SHUFFLE); + return pipeline; + // TODO factor security manager into pipeline + // TODO factor out encode/decode to permit binary shuffle + // TODO factor out decode of index to permit alt. models + } + + } + + class Shuffle extends SimpleChannelUpstreamHandler { + + private final Configuration conf; + private final IndexCache indexCache; + private final LocalDirAllocator lDirAlloc = + new LocalDirAllocator(SHUFFLE_HANDLER_LOCAL_DIRS); + private int port; + + public Shuffle(Configuration conf) { + this.conf = conf; + indexCache = new IndexCache(conf); + this.port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT); + } + + public void setPort(int port) { + this.port = port; + } + + private List<String> splitMaps(List<String> mapq) { + if (null == mapq) { + return null; + } + final List<String> ret = new ArrayList<String>(); + for (String s : mapq) { + Collections.addAll(ret, s.split(",")); + } + return ret; + } + + @Override + public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent evt) + throws Exception { + if ((maxShuffleConnections > 0) && (accepted.size() >= maxShuffleConnections)) { + LOG.info(String.format("Current number of shuffle connections (%d) is " + + "greater than or equal to the max allowed shuffle connections (%d)", + accepted.size(), maxShuffleConnections)); + evt.getChannel().close(); + return; + } + accepted.add(evt.getChannel()); + super.channelOpen(ctx, evt); + + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt) + throws Exception { + HttpRequest request = (HttpRequest) evt.getMessage(); + if (request.getMethod() != GET) { + sendError(ctx, METHOD_NOT_ALLOWED); + return; + } + // Check whether the shuffle version is compatible + if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals( + request.getHeader(ShuffleHeader.HTTP_HEADER_NAME)) + || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals( + request.getHeader(ShuffleHeader.HTTP_HEADER_VERSION))) { + sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST); + } + final Map<String,List<String>> q = + new QueryStringDecoder(request.getUri()).getParameters(); + final List<String> keepAliveList = q.get("keepAlive"); + boolean keepAliveParam = false; + if (keepAliveList != null && keepAliveList.size() == 1) { + keepAliveParam = Boolean.valueOf(keepAliveList.get(0)); + if (LOG.isDebugEnabled()) { + LOG.debug("KeepAliveParam : " + keepAliveList + + " : " + keepAliveParam); + } + } + final List<String> mapIds = splitMaps(q.get("map")); + final List<String> reduceQ = q.get("reduce"); + final List<String> jobQ = q.get("job"); + if (LOG.isDebugEnabled()) { + LOG.debug("RECV: " + request.getUri() + + "\n mapId: " + mapIds + + "\n reduceId: " + reduceQ + + "\n jobId: " + jobQ + + "\n keepAlive: " + keepAliveParam); + } + + if (mapIds == null || reduceQ == null || jobQ == null) { + sendError(ctx, "Required param job, map and reduce", BAD_REQUEST); + return; + } + if (reduceQ.size() != 1 || jobQ.size() != 1) { + sendError(ctx, "Too many job/reduce parameters", BAD_REQUEST); + return; + } + int reduceId; + String jobId; + try { + reduceId = Integer.parseInt(reduceQ.get(0)); + jobId = jobQ.get(0); + } catch (NumberFormatException e) { + sendError(ctx, "Bad reduce parameter", BAD_REQUEST); + return; + } catch (IllegalArgumentException e) { + sendError(ctx, "Bad job parameter", BAD_REQUEST); + return; + } + final String reqUri = request.getUri(); + if (null == reqUri) { + // TODO? add upstream? + sendError(ctx, FORBIDDEN); + return; + } + HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK); + try { + verifyRequest(jobId, ctx, request, response, + new URL("http", "", this.port, reqUri)); + } catch (IOException e) { + LOG.warn("Shuffle failure ", e); + sendError(ctx, e.getMessage(), UNAUTHORIZED); + return; + } + + Map<String, MapOutputInfo> mapOutputInfoMap = + new HashMap<String, MapOutputInfo>(); + Channel ch = evt.getChannel(); + String user = userRsrc.get(jobId); + + // $x/$user/appcache/$appId/output/$mapId + // TODO: Once Shuffle is out of NM, this can use MR APIs to convert + // between App and Job + String outputBasePathStr = getBaseLocation(jobId, user); + + try { + populateHeaders(mapIds, outputBasePathStr, user, reduceId, request, + response, keepAliveParam, mapOutputInfoMap); + } catch(IOException e) { + ch.write(response); + LOG.error("Shuffle error in populating headers :", e); + String errorMessage = getErrorMessage(e); + sendError(ctx,errorMessage , INTERNAL_SERVER_ERROR); + return; + } + ch.write(response); + // TODO refactor the following into the pipeline + ChannelFuture lastMap = null; + for (String mapId : mapIds) { + try { + MapOutputInfo info = mapOutputInfoMap.get(mapId); + if (info == null) { + info = getMapOutputInfo(outputBasePathStr, mapId, reduceId, user); + } + lastMap = + sendMapOutput(ctx, ch, user, mapId, + reduceId, info); + if (null == lastMap) { + sendError(ctx, NOT_FOUND); + return; + } + } catch (IOException e) { + LOG.error("Shuffle error :", e); + String errorMessage = getErrorMessage(e); + sendError(ctx,errorMessage , INTERNAL_SERVER_ERROR); + return; + } + } + lastMap.addListener(ChannelFutureListener.CLOSE); + } + + private String getErrorMessage(Throwable t) { + StringBuffer sb = new StringBuffer(t.getMessage()); + while (t.getCause() != null) { + sb.append(t.getCause().getMessage()); + t = t.getCause(); + } + return sb.toString(); + } + + private final String USERCACHE_CONSTANT = "usercache"; + private final String APPCACHE_CONSTANT = "appcache"; + + private String getBaseLocation(String jobIdString, String user) { + String parts[] = jobIdString.split("_"); + Preconditions.checkArgument(parts.length == 3, "Invalid jobId. Expecting 3 parts"); + final ApplicationId appID = + ApplicationId.newInstance(Long.parseLong(parts[1]), Integer.parseInt(parts[2])); + final String baseStr = + USERCACHE_CONSTANT + "/" + user + "/" + + APPCACHE_CONSTANT + "/" + + ConverterUtils.toString(appID) + "/output" + "/"; + return baseStr; + } + + protected MapOutputInfo getMapOutputInfo(String base, String mapId, + int reduce, String user) throws IOException { + // Index file + Path indexFileName = + lDirAlloc.getLocalPathToRead(base + "/file.out.index", conf); + TezIndexRecord info = + indexCache.getIndexInformation(mapId, reduce, indexFileName, user); + + Path mapOutputFileName = + lDirAlloc.getLocalPathToRead(base + "/file.out", conf); + if (LOG.isDebugEnabled()) { + LOG.debug(base + " : " + mapOutputFileName + " : " + indexFileName); + } + MapOutputInfo outputInfo = new MapOutputInfo(mapOutputFileName, info); + return outputInfo; + } + + protected void populateHeaders(List<String> mapIds, String outputBaseStr, + String user, int reduce, HttpRequest request, HttpResponse response, + boolean keepAliveParam, Map<String, MapOutputInfo> mapOutputInfoMap) + throws IOException { + + long contentLength = 0; + for (String mapId : mapIds) { + String base = outputBaseStr + mapId; + MapOutputInfo outputInfo = getMapOutputInfo(base, mapId, reduce, user); + if (mapOutputInfoMap.size() < mapOutputMetaInfoCacheSize) { + mapOutputInfoMap.put(mapId, outputInfo); + } + // Index file + Path indexFileName = + lDirAlloc.getLocalPathToRead(base + "/file.out.index", conf); + TezIndexRecord info = + indexCache.getIndexInformation(mapId, reduce, indexFileName, user); + ShuffleHeader header = + new ShuffleHeader(mapId, info.getPartLength(), info.getRawLength(), reduce); + DataOutputBuffer dob = new DataOutputBuffer(); + header.write(dob); + + contentLength += info.getPartLength(); + contentLength += dob.getLength(); + } + + // Now set the response headers. + setResponseHeaders(response, keepAliveParam, contentLength); + } + + protected void setResponseHeaders(HttpResponse response, + boolean keepAliveParam, long contentLength) { + if (!connectionKeepAliveEnabled && !keepAliveParam) { + LOG.info("Setting connection close header..."); + response.setHeader(HttpHeaders.Names.CONNECTION, CONNECTION_CLOSE); + } else { + response.setHeader(HttpHeaders.Names.CONTENT_LENGTH, + String.valueOf(contentLength)); + response.setHeader(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.KEEP_ALIVE); + response.setHeader(HttpHeaders.Values.KEEP_ALIVE, "timeout=" + + connectionKeepAliveTimeOut); + LOG.info("Content Length in shuffle : " + contentLength); + } + } + + class MapOutputInfo { + final Path mapOutputFileName; + final TezIndexRecord indexRecord; + + MapOutputInfo(Path mapOutputFileName, TezIndexRecord indexRecord) { + this.mapOutputFileName = mapOutputFileName; + this.indexRecord = indexRecord; + } + } + + protected void verifyRequest(String appid, ChannelHandlerContext ctx, + HttpRequest request, HttpResponse response, URL requestUri) + throws IOException { + SecretKey tokenSecret = secretManager.retrieveTokenSecret(appid); + if (null == tokenSecret) { + LOG.info("Request for unknown token " + appid); + throw new IOException("could not find jobid"); + } + // string to encrypt + String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri); + // hash from the fetcher + String urlHashStr = + request.getHeader(SecureShuffleUtils.HTTP_HEADER_URL_HASH); + if (urlHashStr == null) { + LOG.info("Missing header hash for " + appid); + throw new IOException("fetcher cannot be authenticated"); + } + if (LOG.isDebugEnabled()) { + int len = urlHashStr.length(); + LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." + + urlHashStr.substring(len-len/2, len-1)); + } + // verify - throws exception + SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret); + // verification passed - encode the reply + String reply = + SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8), + tokenSecret); + response.setHeader(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply); + // Put shuffle version into http header + response.setHeader(ShuffleHeader.HTTP_HEADER_NAME, + ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); + response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION, + ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); + if (LOG.isDebugEnabled()) { + int len = reply.length(); + LOG.debug("Fetcher request verfied. enc_str=" + enc_str + ";reply=" + + reply.substring(len-len/2, len-1)); + } + } + + protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch, + String user, String mapId, int reduce, MapOutputInfo mapOutputInfo) + throws IOException { + final TezIndexRecord info = mapOutputInfo.indexRecord; + final ShuffleHeader header = + new ShuffleHeader(mapId, info.getPartLength(), info.getRawLength(), reduce); + final DataOutputBuffer dob = new DataOutputBuffer(); + header.write(dob); + ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + final File spillfile = + new File(mapOutputInfo.mapOutputFileName.toString()); + RandomAccessFile spill; + try { + spill = SecureIOUtils.openForRandomRead(spillfile, "r", user, null); + } catch (FileNotFoundException e) { + LOG.info(spillfile + " not found"); + return null; + } + ChannelFuture writeFuture; + if (ch.getPipeline().get(SslHandler.class) == null) { + final FadvisedFileRegion partition = new FadvisedFileRegion(spill, + info.getStartOffset(), info.getPartLength(), manageOsCache, readaheadLength, + readaheadPool, spillfile.getAbsolutePath(), + shuffleBufferSize, shuffleTransferToAllowed); + writeFuture = ch.write(partition); + writeFuture.addListener(new ChannelFutureListener() { + // TODO error handling; distinguish IO/connection failures, + // attribute to appropriate spill output + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + partition.transferSuccessful(); + } + partition.releaseExternalResources(); + } + }); + } else { + // HTTPS cannot be done with zero copy. + final FadvisedChunkedFile chunk = new FadvisedChunkedFile(spill, + info.getStartOffset(), info.getPartLength(), sslFileBufferSize, + manageOsCache, readaheadLength, readaheadPool, + spillfile.getAbsolutePath()); + writeFuture = ch.write(chunk); + } + return writeFuture; + } + + protected void sendError(ChannelHandlerContext ctx, + HttpResponseStatus status) { + sendError(ctx, "", status); + } + + protected void sendError(ChannelHandlerContext ctx, String message, + HttpResponseStatus status) { + HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status); + response.setHeader(CONTENT_TYPE, "text/plain; charset=UTF-8"); + // Put shuffle version into http header + response.setHeader(ShuffleHeader.HTTP_HEADER_NAME, + ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); + response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION, + ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); + response.setContent( + ChannelBuffers.copiedBuffer(message, CharsetUtil.UTF_8)); + + // Close the connection as soon as the error message is sent. + ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) + throws Exception { + Channel ch = e.getChannel(); + Throwable cause = e.getCause(); + if (cause instanceof TooLongFrameException) { + sendError(ctx, BAD_REQUEST); + return; + } else if (cause instanceof IOException) { + if (cause instanceof ClosedChannelException) { + LOG.debug("Ignoring closed channel error", cause); + return; + } + String message = String.valueOf(cause.getMessage()); + if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) { + LOG.debug("Ignoring client socket close", cause); + return; + } + } + + LOG.error("Shuffle error: ", cause); + if (ch.isConnected()) { + LOG.error("Shuffle error " + e); + sendError(ctx, INTERNAL_SERVER_ERROR); + } + } + } +}
