TEZ-3355. Tez Custom Shuffle Handler POC (jeagles)
Project: http://git-wip-us.apache.org/repos/asf/tez/repo Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/077dd88e Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/077dd88e Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/077dd88e Branch: refs/heads/TEZ-3334 Commit: 077dd88e03b6e055b2c0bd8b7cb1986c7775658d Parents: 97fa44f Author: Jonathan Eagles <[email protected]> Authored: Mon Jul 25 10:29:31 2016 -0500 Committer: Jonathan Eagles <[email protected]> Committed: Mon Jul 25 10:29:31 2016 -0500 ---------------------------------------------------------------------- TEZ-3334-CHANGES.txt | 7 + pom.xml | 25 + tez-dist/src/main/assembly/tez-dist-minimal.xml | 3 + tez-dist/src/main/assembly/tez-dist.xml | 3 + tez-plugins/pom.xml | 2 + .../tez-aux-services/findbugs-exclude.xml | 16 + tez-plugins/tez-aux-services/pom.xml | 108 ++ .../org/apache/tez/auxservices/IndexCache.java | 195 +++ .../apache/tez/auxservices/ShuffleHandler.java | 1343 ++++++++++++++++++ .../tez/auxservices/TestShuffleHandler.java | 1127 +++++++++++++++ tez-plugins/tez-history-parser/pom.xml | 5 - .../tez-yarn-timeline-history-with-acls/pom.xml | 5 - 12 files changed, 2829 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/TEZ-3334-CHANGES.txt ---------------------------------------------------------------------- diff --git a/TEZ-3334-CHANGES.txt b/TEZ-3334-CHANGES.txt new file mode 100644 index 0000000..f779000 --- /dev/null +++ b/TEZ-3334-CHANGES.txt @@ -0,0 +1,7 @@ +Apache Tez Change Log +===================== + +INCOMPATIBLE CHANGES: + +ALL CHANGES: + TEZ-3355. Tez Custom Shuffle Handler POC http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 336f5cb..6e4fe40 100644 --- a/pom.xml +++ b/pom.xml @@ -580,6 +580,31 @@ </dependency> <dependency> <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-server-common</artifactId> + <version>${hadoop.version}</version> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-mapreduce-client-shuffle</artifactId> + <scope>provided</scope> + <version>${hadoop.version}</version> + <exclusions> + <exclusion> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-server-common</artifactId> + </exclusion> + <exclusion> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-server-nodemanager</artifactId> + </exclusion> + <exclusion> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-mapreduce-client-common</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-mapreduce-client-jobclient</artifactId> <scope>test</scope> <type>test-jar</type> http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-dist/src/main/assembly/tez-dist-minimal.xml ---------------------------------------------------------------------- diff --git a/tez-dist/src/main/assembly/tez-dist-minimal.xml b/tez-dist/src/main/assembly/tez-dist-minimal.xml index 869e5b0..80633ff 100644 --- a/tez-dist/src/main/assembly/tez-dist-minimal.xml +++ b/tez-dist/src/main/assembly/tez-dist-minimal.xml @@ -22,6 +22,9 @@ <moduleSets> <moduleSet> <useAllReactorProjects>true</useAllReactorProjects> + <excludes> + <exclude>org.apache.tez:tez-aux-services</exclude> + </excludes> <binaries> <outputDirectory>/</outputDirectory> <unpack>false</unpack> 
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-dist/src/main/assembly/tez-dist.xml ---------------------------------------------------------------------- diff --git a/tez-dist/src/main/assembly/tez-dist.xml b/tez-dist/src/main/assembly/tez-dist.xml index a181546..b8834a8 100644 --- a/tez-dist/src/main/assembly/tez-dist.xml +++ b/tez-dist/src/main/assembly/tez-dist.xml @@ -22,6 +22,9 @@ <moduleSets> <moduleSet> <useAllReactorProjects>true</useAllReactorProjects> + <excludes> + <exclude>org.apache.tez:tez-aux-services</exclude> + </excludes> <binaries> <outputDirectory>/</outputDirectory> <unpack>false</unpack> http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/pom.xml ---------------------------------------------------------------------- diff --git a/tez-plugins/pom.xml b/tez-plugins/pom.xml index 27707a8..ffe59b9 100644 --- a/tez-plugins/pom.xml +++ b/tez-plugins/pom.xml @@ -48,6 +48,7 @@ <module>tez-yarn-timeline-history</module> <module>tez-yarn-timeline-history-with-acls</module> <module>tez-history-parser</module> + <module>tez-aux-services</module> </modules> </profile> <profile> @@ -61,6 +62,7 @@ <module>tez-yarn-timeline-cache-plugin</module> <module>tez-yarn-timeline-history-with-fs</module> <module>tez-history-parser</module> + <module>tez-aux-services</module> </modules> </profile> http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/findbugs-exclude.xml ---------------------------------------------------------------------- diff --git a/tez-plugins/tez-aux-services/findbugs-exclude.xml b/tez-plugins/tez-aux-services/findbugs-exclude.xml new file mode 100644 index 0000000..5b11308 --- /dev/null +++ b/tez-plugins/tez-aux-services/findbugs-exclude.xml @@ -0,0 +1,16 @@ +<!-- + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. See accompanying LICENSE file. +--> +<FindBugsFilter> + +</FindBugsFilter> http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/pom.xml ---------------------------------------------------------------------- diff --git a/tez-plugins/tez-aux-services/pom.xml b/tez-plugins/tez-aux-services/pom.xml new file mode 100644 index 0000000..c30555b --- /dev/null +++ b/tez-plugins/tez-aux-services/pom.xml @@ -0,0 +1,108 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Licensed under the Apache License, Version 2.0 (the "License"); + ~ you may not use this file except in compliance with the License. + ~ You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. 
+ --> + +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <artifactId>tez-plugins</artifactId> + <groupId>org.apache.tez</groupId> + <version>0.9.0-SNAPSHOT</version> + </parent> + + <artifactId>tez-aux-services</artifactId> + + <dependencies> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-log4j12</artifactId> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-common</artifactId> + </dependency> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <!-- Required for the ShuffleHandler --> + <groupId>org.apache.tez</groupId> + <artifactId>tez-runtime-library</artifactId> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-server-common</artifactId> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-mapreduce-client-core</artifactId> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-mapreduce-client-shuffle</artifactId> + </dependency> + <dependency> + <groupId>org.apache.tez</groupId> + <artifactId>tez-mapreduce</artifactId> + </dependency> + <dependency> + <groupId>org.mortbay.jetty</groupId> + <artifactId>jetty</artifactId> + </dependency> + </dependencies> + + <build> + <!-- + Include all files in src/main/resources. By default, do not apply property + substitution (filtering=false), but do apply property substitution to + version-info.properties (filtering=true). This will substitute the + version information correctly, but prevent Maven from altering other files. + --> + <resources> + <resource> + <directory>${basedir}/src/main/resources</directory> + <excludes> + <exclude>tez-api-version-info.properties</exclude> + </excludes> + <filtering>false</filtering> + </resource> + <resource> + <directory>${basedir}/src/main/resources</directory> + <includes> + <include>tez-api-version-info.properties</include> + </includes> + <filtering>true</filtering> + </resource> + </resources> + <plugins> + <plugin> + <groupId>org.apache.rat</groupId> + <artifactId>apache-rat-plugin</artifactId> + </plugin> + </plugins> + </build> + +</project> http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/IndexCache.java ---------------------------------------------------------------------- diff --git a/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/IndexCache.java b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/IndexCache.java new file mode 100644 index 0000000..532187e --- /dev/null +++ b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/IndexCache.java @@ -0,0 +1,195 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tez.auxservices; + +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.tez.runtime.library.common.Constants; +import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord; +import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class IndexCache { + + private final Configuration conf; + private final int totalMemoryAllowed; + private AtomicInteger totalMemoryUsed = new AtomicInteger(); + private static final Logger LOG = LoggerFactory.getLogger(IndexCache.class); + + private final ConcurrentHashMap<String,IndexInformation> cache = + new ConcurrentHashMap<String,IndexInformation>(); + + private final LinkedBlockingQueue<String> queue = + new LinkedBlockingQueue<String>(); + + public IndexCache(Configuration conf) { + this.conf = conf; + totalMemoryAllowed = 10 * 1024 * 1024; + LOG.info("IndexCache created with max memory = " + totalMemoryAllowed); + } + + /** + * This method gets the index information for the given mapId and reduce. + * It reads the index file into cache if it is not already present. 
+ * @param mapId + * @param reduce + * @param fileName The file to read the index information from if it is not + * already present in the cache + * @param expectedIndexOwner The expected owner of the index file + * @return The Index Information + * @throws IOException + */ + public TezIndexRecord getIndexInformation(String mapId, int reduce, + Path fileName, String expectedIndexOwner) + throws IOException { + + IndexInformation info = cache.get(mapId); + + if (info == null) { + info = readIndexFileToCache(fileName, mapId, expectedIndexOwner); + } else { + synchronized(info) { + while (isUnderConstruction(info)) { + try { + info.wait(); + } catch (InterruptedException e) { + throw new IOException("Interrupted waiting for construction", e); + } + } + } + LOG.debug("IndexCache HIT: MapId " + mapId + " found"); + } + + if (info.mapSpillRecord.size() == 0 || + info.mapSpillRecord.size() <= reduce) { + throw new IOException("Invalid request " + + " Map Id = " + mapId + " Reducer = " + reduce + + " Index Info Length = " + info.mapSpillRecord.size()); + } + return info.mapSpillRecord.getIndex(reduce); + } + + private boolean isUnderConstruction(IndexInformation info) { + synchronized(info) { + return (null == info.mapSpillRecord); + } + } + + private IndexInformation readIndexFileToCache(Path indexFileName, + String mapId, + String expectedIndexOwner) + throws IOException { + IndexInformation info; + IndexInformation newInd = new IndexInformation(); + if ((info = cache.putIfAbsent(mapId, newInd)) != null) { + synchronized(info) { + while (isUnderConstruction(info)) { + try { + info.wait(); + } catch (InterruptedException e) { + throw new IOException("Interrupted waiting for construction", e); + } + } + } + LOG.debug("IndexCache HIT: MapId " + mapId + " found"); + return info; + } + LOG.debug("IndexCache MISS: MapId " + mapId + " not found") ; + TezSpillRecord tmp = null; + try { + tmp = new TezSpillRecord(indexFileName, conf, expectedIndexOwner); + } catch (Throwable e) { + tmp = new TezSpillRecord(0); + cache.remove(mapId); + throw new IOException("Error Reading IndexFile", e); + } finally { + synchronized (newInd) { + newInd.mapSpillRecord = tmp; + newInd.notifyAll(); + } + } + queue.add(mapId); + + if (totalMemoryUsed.addAndGet(newInd.getSize()) > totalMemoryAllowed) { + freeIndexInformation(); + } + return newInd; + } + + /** + * This method removes the map from the cache if index information for this + * map is loaded(size>0), index information entry in cache will not be + * removed if it is in the loading phrase(size=0), this prevents corruption + * of totalMemoryUsed. It should be called when a map output on this tracker + * is discarded. + * @param mapId The taskID of this map. + */ + public void removeMap(String mapId) { + IndexInformation info = cache.get(mapId); + if (info == null || ((info != null) && isUnderConstruction(info))) { + return; + } + info = cache.remove(mapId); + if (info != null) { + totalMemoryUsed.addAndGet(-info.getSize()); + if (!queue.remove(mapId)) { + LOG.warn("Map ID" + mapId + " not found in queue!!"); + } + } else { + LOG.info("Map ID " + mapId + " not found in cache"); + } + } + + /** + * This method checks if cache and totolMemoryUsed is consistent. + * It is only used for unit test. 
+ * @return True if cache and totolMemoryUsed is consistent + */ + boolean checkTotalMemoryUsed() { + int totalSize = 0; + for (IndexInformation info : cache.values()) { + totalSize += info.getSize(); + } + return totalSize == totalMemoryUsed.get(); + } + + /** + * Bring memory usage below totalMemoryAllowed. + */ + private synchronized void freeIndexInformation() { + while (totalMemoryUsed.get() > totalMemoryAllowed) { + String s = queue.remove(); + IndexInformation info = cache.remove(s); + if (info != null) { + totalMemoryUsed.addAndGet(-info.getSize()); + } + } + } + + private static class IndexInformation { + TezSpillRecord mapSpillRecord; + + int getSize() { + return mapSpillRecord == null + ? 0 + : mapSpillRecord.size() * Constants.MAP_OUTPUT_INDEX_RECORD_LENGTH; + } + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java ---------------------------------------------------------------------- diff --git a/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java new file mode 100644 index 0000000..c8eb238 --- /dev/null +++ b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java @@ -0,0 +1,1343 @@ +/** +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.tez.auxservices; + +import static org.fusesource.leveldbjni.JniDBFactory.asString; +import static org.fusesource.leveldbjni.JniDBFactory.bytes; +import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer; +import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE; +import static org.jboss.netty.handler.codec.http.HttpMethod.GET; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED; +import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.net.InetSocketAddress; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Pattern; + +import javax.crypto.SecretKey; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.LocalDirAllocator; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DataInputByteBuffer; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.ReadaheadPool; +import org.apache.hadoop.io.SecureIOUtils; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.FadvisedChunkedFile; +import org.apache.hadoop.mapred.FadvisedFileRegion; +import org.apache.hadoop.mapred.proto.ShuffleHandlerRecoveryProtos.JobShuffleInfoProto; +import org.apache.hadoop.mapreduce.JobID; +import org.apache.tez.mapreduce.hadoop.MRConfig; +import org.apache.tez.common.security.JobTokenIdentifier; +import org.apache.tez.common.security.JobTokenSecretManager; +import org.apache.tez.runtime.library.common.security.SecureShuffleUtils; +import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ShuffleHeader; +import org.apache.hadoop.metrics2.MetricsSystem; +import org.apache.hadoop.metrics2.annotation.Metric; +import org.apache.hadoop.metrics2.annotation.Metrics; +import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem; +import org.apache.hadoop.metrics2.lib.MutableCounterInt; +import org.apache.hadoop.metrics2.lib.MutableCounterLong; +import org.apache.hadoop.metrics2.lib.MutableGaugeInt; +import org.apache.hadoop.security.proto.SecurityProtos.TokenProto; +import org.apache.hadoop.security.ssl.SSLFactory; +import org.apache.hadoop.security.token.Token; +import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord; +import org.apache.hadoop.util.Shell; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import 
org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.VersionProto; +import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; +import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; +import org.apache.hadoop.yarn.server.api.AuxiliaryService; +import org.apache.hadoop.yarn.server.records.Version; +import org.apache.hadoop.yarn.server.records.impl.pb.VersionPBImpl; +import org.apache.hadoop.yarn.server.utils.LeveldbIterator; +import org.fusesource.leveldbjni.JniDBFactory; +import org.fusesource.leveldbjni.internal.NativeDB; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.DBException; +import org.iq80.leveldb.Logger; +import org.iq80.leveldb.Options; +import org.jboss.netty.bootstrap.ServerBootstrap; +import org.jboss.netty.buffer.ChannelBuffers; +import org.jboss.netty.channel.Channel; +import org.jboss.netty.channel.ChannelFactory; +import org.jboss.netty.channel.ChannelFuture; +import org.jboss.netty.channel.ChannelFutureListener; +import org.jboss.netty.channel.ChannelHandlerContext; +import org.jboss.netty.channel.ChannelPipeline; +import org.jboss.netty.channel.ChannelPipelineFactory; +import org.jboss.netty.channel.ChannelStateEvent; +import org.jboss.netty.channel.Channels; +import org.jboss.netty.channel.ExceptionEvent; +import org.jboss.netty.channel.MessageEvent; +import org.jboss.netty.channel.SimpleChannelUpstreamHandler; +import org.jboss.netty.channel.group.ChannelGroup; +import org.jboss.netty.channel.group.DefaultChannelGroup; +import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; +import org.jboss.netty.handler.codec.frame.TooLongFrameException; +import org.jboss.netty.handler.codec.http.DefaultHttpResponse; +import org.jboss.netty.handler.codec.http.HttpChunkAggregator; +import org.jboss.netty.handler.codec.http.HttpRequest; +import org.jboss.netty.handler.codec.http.HttpRequestDecoder; +import org.jboss.netty.handler.codec.http.HttpResponse; +import org.jboss.netty.handler.codec.http.HttpResponseEncoder; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; +import org.jboss.netty.handler.codec.http.QueryStringDecoder; +import org.jboss.netty.handler.ssl.SslHandler; +import org.jboss.netty.handler.stream.ChunkedWriteHandler; +import org.jboss.netty.util.CharsetUtil; +import org.mortbay.jetty.HttpHeaders; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Charsets; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.cache.RemovalListener; +import com.google.common.cache.RemovalNotification; +import com.google.common.cache.Weigher; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import com.google.protobuf.ByteString; + +public class ShuffleHandler extends AuxiliaryService { + + private static final Log LOG = LogFactory.getLog(ShuffleHandler.class); + private static final Log AUDITLOG = + LogFactory.getLog(ShuffleHandler.class.getName()+".audit"); + public static final String SHUFFLE_MANAGE_OS_CACHE = "mapreduce.shuffle.manage.os.cache"; + public static final boolean DEFAULT_SHUFFLE_MANAGE_OS_CACHE = true; + + public static final String SHUFFLE_READAHEAD_BYTES = "mapreduce.shuffle.readahead.bytes"; + public static final int DEFAULT_SHUFFLE_READAHEAD_BYTES = 4 * 1024 * 1024; + public static final String USERCACHE = "usercache"; + public static final String APPCACHE = "appcache"; + + // pattern 
to identify errors related to the client closing the socket early + // idea borrowed from Netty SslHandler + private static final Pattern IGNORABLE_ERROR_MESSAGE = Pattern.compile( + "^.*(?:connection.*reset|connection.*closed|broken.*pipe).*$", + Pattern.CASE_INSENSITIVE); + + private static final String STATE_DB_NAME = "mapreduce_shuffle_state"; + private static final String STATE_DB_SCHEMA_VERSION_KEY = "shuffle-schema-version"; + protected static final Version CURRENT_VERSION_INFO = + Version.newInstance(1, 0); + + private static final String DATA_FILE_NAME = "file.out"; + private static final String INDEX_FILE_NAME = "file.out.index"; + + private int port; + private ChannelFactory selector; + private final ChannelGroup accepted = new DefaultChannelGroup(); + protected HttpPipelineFactory pipelineFact; + private int sslFileBufferSize; + + /** + * Should the shuffle use posix_fadvise calls to manage the OS cache during + * sendfile + */ + private boolean manageOsCache; + private int readaheadLength; + private int maxShuffleConnections; + private int shuffleBufferSize; + private boolean shuffleTransferToAllowed; + private int maxSessionOpenFiles; + private ReadaheadPool readaheadPool = ReadaheadPool.getInstance(); + + private Map<String,String> userRsrc; + private JobTokenSecretManager secretManager; + + private DB stateDb = null; + + public static final String MAPREDUCE_SHUFFLE_SERVICEID = + "mapreduce_shuffle"; + + public static final String SHUFFLE_PORT_CONFIG_KEY = "mapreduce.shuffle.port"; + public static final int DEFAULT_SHUFFLE_PORT = 13562; + + public static final String SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED = + "mapreduce.shuffle.connection-keep-alive.enable"; + public static final boolean DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED = false; + + public static final String SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT = + "mapreduce.shuffle.connection-keep-alive.timeout"; + public static final int DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT = 5; //seconds + + public static final String SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE = + "mapreduce.shuffle.mapoutput-info.meta.cache.size"; + public static final int DEFAULT_SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE = + 1000; + + public static final String CONNECTION_CLOSE = "close"; + + public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY = + "mapreduce.shuffle.ssl.file.buffer.size"; + + public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024; + + public static final String MAX_SHUFFLE_CONNECTIONS = "mapreduce.shuffle.max.connections"; + public static final int DEFAULT_MAX_SHUFFLE_CONNECTIONS = 0; // 0 implies no limit + + public static final String MAX_SHUFFLE_THREADS = "mapreduce.shuffle.max.threads"; + // 0 implies Netty default of 2 * number of available processors + public static final int DEFAULT_MAX_SHUFFLE_THREADS = 0; + + public static final String SHUFFLE_BUFFER_SIZE = + "mapreduce.shuffle.transfer.buffer.size"; + public static final int DEFAULT_SHUFFLE_BUFFER_SIZE = 128 * 1024; + + public static final String SHUFFLE_TRANSFERTO_ALLOWED = + "mapreduce.shuffle.transferTo.allowed"; + public static final boolean DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = true; + public static final boolean WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = + false; + + /* the maximum number of files a single GET request can + open simultaneously during shuffle + */ + public static final String SHUFFLE_MAX_SESSION_OPEN_FILES = + "mapreduce.shuffle.max.session-open-files"; + public static final int DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES = 3; + + 
boolean connectionKeepAliveEnabled = false; + int connectionKeepAliveTimeOut; + int mapOutputMetaInfoCacheSize; + + @Metrics(about="Shuffle output metrics", context="mapred") + static class ShuffleMetrics implements ChannelFutureListener { + @Metric("Shuffle output in bytes") + MutableCounterLong shuffleOutputBytes; + @Metric("# of failed shuffle outputs") + MutableCounterInt shuffleOutputsFailed; + @Metric("# of succeeeded shuffle outputs") + MutableCounterInt shuffleOutputsOK; + @Metric("# of current shuffle connections") + MutableGaugeInt shuffleConnections; + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + shuffleOutputsOK.incr(); + } else { + shuffleOutputsFailed.incr(); + } + shuffleConnections.decr(); + } + } + + final ShuffleMetrics metrics; + + class ReduceMapFileCount implements ChannelFutureListener { + + private ReduceContext reduceContext; + + public ReduceMapFileCount(ReduceContext rc) { + this.reduceContext = rc; + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + future.getChannel().close(); + return; + } + int waitCount = this.reduceContext.getMapsToWait().decrementAndGet(); + if (waitCount == 0) { + metrics.operationComplete(future); + future.getChannel().close(); + } else { + pipelineFact.getSHUFFLE().sendMap(reduceContext); + } + } + } + + /** + * Maintain parameters per messageReceived() Netty context. + * Allows sendMapOutput calls from operationComplete() + */ + private static class ReduceContext { + + private List<String> mapIds; + private AtomicInteger mapsToWait; + private AtomicInteger mapsToSend; + private int reduceId; + private ChannelHandlerContext ctx; + private String user; + private Map<String, Shuffle.MapOutputInfo> infoMap; + private String jobId; + + public ReduceContext(List<String> mapIds, int rId, + ChannelHandlerContext context, String usr, + Map<String, Shuffle.MapOutputInfo> mapOutputInfoMap, + String jobId) { + + this.mapIds = mapIds; + this.reduceId = rId; + /** + * Atomic count for tracking the no. of map outputs that are yet to + * complete. Multiple futureListeners' operationComplete() can decrement + * this value asynchronously. It is used to decide when the channel should + * be closed. + */ + this.mapsToWait = new AtomicInteger(mapIds.size()); + /** + * Atomic count for tracking the no. of map outputs that have been sent. + * Multiple sendMap() calls can increment this value + * asynchronously. Used to decide which mapId should be sent next. + */ + this.mapsToSend = new AtomicInteger(0); + this.ctx = context; + this.user = usr; + this.infoMap = mapOutputInfoMap; + this.jobId = jobId; + } + + public int getReduceId() { + return reduceId; + } + + public ChannelHandlerContext getCtx() { + return ctx; + } + + public String getUser() { + return user; + } + + public Map<String, Shuffle.MapOutputInfo> getInfoMap() { + return infoMap; + } + + public String getJobId() { + return jobId; + } + + public List<String> getMapIds() { + return mapIds; + } + + public AtomicInteger getMapsToSend() { + return mapsToSend; + } + + public AtomicInteger getMapsToWait() { + return mapsToWait; + } + } + + ShuffleHandler(MetricsSystem ms) { + super(MAPREDUCE_SHUFFLE_SERVICEID); + metrics = ms.register(new ShuffleMetrics()); + } + + public ShuffleHandler() { + this(DefaultMetricsSystem.instance()); + } + + /** + * Serialize the shuffle port into a ByteBuffer for use later on. 
+ * @param port the port to be sent to the ApplciationMaster + * @return the serialized form of the port. + */ + public static ByteBuffer serializeMetaData(int port) throws IOException { + //TODO these bytes should be versioned + DataOutputBuffer port_dob = new DataOutputBuffer(); + port_dob.writeInt(port); + return ByteBuffer.wrap(port_dob.getData(), 0, port_dob.getLength()); + } + + /** + * A helper function to deserialize the metadata returned by ShuffleHandler. + * @param meta the metadata returned by the ShuffleHandler + * @return the port the Shuffle Handler is listening on to serve shuffle data. + */ + public static int deserializeMetaData(ByteBuffer meta) throws IOException { + //TODO this should be returning a class not just an int + DataInputByteBuffer in = new DataInputByteBuffer(); + in.reset(meta); + int port = in.readInt(); + return port; + } + + /** + * A helper function to serialize the JobTokenIdentifier to be sent to the + * ShuffleHandler as ServiceData. + * @param jobToken the job token to be used for authentication of + * shuffle data requests. + * @return the serialized version of the jobToken. + */ + public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken) throws IOException { + //TODO these bytes should be versioned + DataOutputBuffer jobToken_dob = new DataOutputBuffer(); + jobToken.write(jobToken_dob); + return ByteBuffer.wrap(jobToken_dob.getData(), 0, jobToken_dob.getLength()); + } + + static Token<JobTokenIdentifier> deserializeServiceData(ByteBuffer secret) throws IOException { + DataInputByteBuffer in = new DataInputByteBuffer(); + in.reset(secret); + Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>(); + jt.readFields(in); + return jt; + } + + @Override + public void initializeApplication(ApplicationInitializationContext context) { + + String user = context.getUser(); + ApplicationId appId = context.getApplicationId(); + ByteBuffer secret = context.getApplicationDataForService(); + // TODO these bytes should be versioned + try { + Token<JobTokenIdentifier> jt = deserializeServiceData(secret); + // TODO: Once SHuffle is out of NM, this can use MR APIs + JobID jobId = new JobID(Long.toString(appId.getClusterTimestamp()), appId.getId()); + recordJobShuffleInfo(jobId, user, jt); + } catch (IOException e) { + LOG.error("Error during initApp", e); + // TODO add API to AuxiliaryServices to report failures + } + } + + @Override + public void stopApplication(ApplicationTerminationContext context) { + ApplicationId appId = context.getApplicationId(); + JobID jobId = new JobID(Long.toString(appId.getClusterTimestamp()), appId.getId()); + try { + removeJobShuffleInfo(jobId); + } catch (IOException e) { + LOG.error("Error during stopApp", e); + // TODO add API to AuxiliaryServices to report failures + } + } + + @Override + protected void serviceInit(Configuration conf) throws Exception { + manageOsCache = conf.getBoolean(SHUFFLE_MANAGE_OS_CACHE, + DEFAULT_SHUFFLE_MANAGE_OS_CACHE); + + readaheadLength = conf.getInt(SHUFFLE_READAHEAD_BYTES, + DEFAULT_SHUFFLE_READAHEAD_BYTES); + + maxShuffleConnections = conf.getInt(MAX_SHUFFLE_CONNECTIONS, + DEFAULT_MAX_SHUFFLE_CONNECTIONS); + int maxShuffleThreads = conf.getInt(MAX_SHUFFLE_THREADS, + DEFAULT_MAX_SHUFFLE_THREADS); + if (maxShuffleThreads == 0) { + maxShuffleThreads = 2 * Runtime.getRuntime().availableProcessors(); + } + + shuffleBufferSize = conf.getInt(SHUFFLE_BUFFER_SIZE, + DEFAULT_SHUFFLE_BUFFER_SIZE); + + shuffleTransferToAllowed = conf.getBoolean(SHUFFLE_TRANSFERTO_ALLOWED, + 
(Shell.WINDOWS)?WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED: + DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED); + + maxSessionOpenFiles = conf.getInt(SHUFFLE_MAX_SESSION_OPEN_FILES, + DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES); + + ThreadFactory bossFactory = new ThreadFactoryBuilder() + .setNameFormat("ShuffleHandler Netty Boss #%d") + .build(); + ThreadFactory workerFactory = new ThreadFactoryBuilder() + .setNameFormat("ShuffleHandler Netty Worker #%d") + .build(); + + selector = new NioServerSocketChannelFactory( + Executors.newCachedThreadPool(bossFactory), + Executors.newCachedThreadPool(workerFactory), + maxShuffleThreads); + super.serviceInit(new YarnConfiguration(conf)); + } + + // TODO change AbstractService to throw InterruptedException + @Override + protected void serviceStart() throws Exception { + Configuration conf = getConfig(); + userRsrc = new ConcurrentHashMap<String,String>(); + secretManager = new JobTokenSecretManager(); + recoverState(conf); + ServerBootstrap bootstrap = new ServerBootstrap(selector); + try { + pipelineFact = new HttpPipelineFactory(conf); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + bootstrap.setOption("child.keepAlive", true); + bootstrap.setPipelineFactory(pipelineFact); + port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT); + Channel ch = bootstrap.bind(new InetSocketAddress(port)); + accepted.add(ch); + port = ((InetSocketAddress)ch.getLocalAddress()).getPort(); + conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port)); + pipelineFact.SHUFFLE.setPort(port); + LOG.info(getName() + " listening on port " + port); + super.serviceStart(); + + sslFileBufferSize = conf.getInt(SUFFLE_SSL_FILE_BUFFER_SIZE_KEY, + DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE); + connectionKeepAliveEnabled = + conf.getBoolean(SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, + DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED); + connectionKeepAliveTimeOut = + Math.max(1, conf.getInt(SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, + DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT)); + mapOutputMetaInfoCacheSize = + Math.max(1, conf.getInt(SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE, + DEFAULT_SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE)); + } + + @Override + protected void serviceStop() throws Exception { + accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS); + if (selector != null) { + ServerBootstrap bootstrap = new ServerBootstrap(selector); + bootstrap.releaseExternalResources(); + } + if (pipelineFact != null) { + pipelineFact.destroy(); + } + if (stateDb != null) { + stateDb.close(); + } + super.serviceStop(); + } + + @Override + public synchronized ByteBuffer getMetaData() { + try { + return serializeMetaData(port); + } catch (IOException e) { + LOG.error("Error during getMeta", e); + // TODO add API to AuxiliaryServices to report failures + return null; + } + } + + protected Shuffle getShuffle(Configuration conf) { + return new Shuffle(conf); + } + + private void recoverState(Configuration conf) throws IOException { + Path recoveryRoot = getRecoveryPath(); + if (recoveryRoot != null) { + startStore(recoveryRoot); + Pattern jobPattern = Pattern.compile(JobID.JOBID_REGEX); + LeveldbIterator iter = null; + try { + iter = new LeveldbIterator(stateDb); + iter.seek(bytes(JobID.JOB)); + while (iter.hasNext()) { + Map.Entry<byte[],byte[]> entry = iter.next(); + String key = asString(entry.getKey()); + if (!jobPattern.matcher(key).matches()) { + break; + } + recoverJobShuffleInfo(key, entry.getValue()); + } + } catch (DBException e) { + throw new IOException("Database error during 
recovery", e); + } finally { + if (iter != null) { + iter.close(); + } + } + } + } + + private void startStore(Path recoveryRoot) throws IOException { + Options options = new Options(); + options.createIfMissing(false); + options.logger(new LevelDBLogger()); + Path dbPath = new Path(recoveryRoot, STATE_DB_NAME); + LOG.info("Using state database at " + dbPath + " for recovery"); + File dbfile = new File(dbPath.toString()); + try { + stateDb = JniDBFactory.factory.open(dbfile, options); + } catch (NativeDB.DBException e) { + if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { + LOG.info("Creating state database at " + dbfile); + options.createIfMissing(true); + try { + stateDb = JniDBFactory.factory.open(dbfile, options); + storeVersion(); + } catch (DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + } else { + throw e; + } + } + checkVersion(); + } + + @VisibleForTesting + Version loadVersion() throws IOException { + byte[] data = stateDb.get(bytes(STATE_DB_SCHEMA_VERSION_KEY)); + // if version is not stored previously, treat it as CURRENT_VERSION_INFO. + if (data == null || data.length == 0) { + return getCurrentVersion(); + } + Version version = + new VersionPBImpl(VersionProto.parseFrom(data)); + return version; + } + + private void storeSchemaVersion(Version version) throws IOException { + String key = STATE_DB_SCHEMA_VERSION_KEY; + byte[] data = + ((VersionPBImpl) version).getProto().toByteArray(); + try { + stateDb.put(bytes(key), data); + } catch (DBException e) { + throw new IOException(e.getMessage(), e); + } + } + + private void storeVersion() throws IOException { + storeSchemaVersion(CURRENT_VERSION_INFO); + } + + // Only used for test + @VisibleForTesting + void storeVersion(Version version) throws IOException { + storeSchemaVersion(version); + } + + protected Version getCurrentVersion() { + return CURRENT_VERSION_INFO; + } + + /** + * 1) Versioning scheme: major.minor. For e.g. 1.0, 1.1, 1.2...1.25, 2.0 etc. + * 2) Any incompatible change of DB schema is a major upgrade, and any + * compatible change of DB schema is a minor upgrade. + * 3) Within a minor upgrade, say 1.1 to 1.2: + * overwrite the version info and proceed as normal. + * 4) Within a major upgrade, say 1.2 to 2.0: + * throw exception and indicate user to use a separate upgrade tool to + * upgrade shuffle info or remove incompatible old state. 
+ */ + private void checkVersion() throws IOException { + Version loadedVersion = loadVersion(); + LOG.info("Loaded state DB schema version info " + loadedVersion); + if (loadedVersion.equals(getCurrentVersion())) { + return; + } + if (loadedVersion.isCompatibleTo(getCurrentVersion())) { + LOG.info("Storing state DB schedma version info " + getCurrentVersion()); + storeVersion(); + } else { + throw new IOException( + "Incompatible version for state DB schema: expecting DB schema version " + + getCurrentVersion() + ", but loading version " + loadedVersion); + } + } + + private void addJobToken(JobID jobId, String user, + Token<JobTokenIdentifier> jobToken) { + userRsrc.put(jobId.toString(), user); + secretManager.addTokenForJob(jobId.toString(), jobToken); + LOG.info("Added token for " + jobId.toString()); + } + + private void recoverJobShuffleInfo(String jobIdStr, byte[] data) + throws IOException { + JobID jobId; + try { + jobId = JobID.forName(jobIdStr); + } catch (IllegalArgumentException e) { + throw new IOException("Bad job ID " + jobIdStr + " in state store", e); + } + + JobShuffleInfoProto proto = JobShuffleInfoProto.parseFrom(data); + String user = proto.getUser(); + TokenProto tokenProto = proto.getJobToken(); + Token<JobTokenIdentifier> jobToken = new Token<JobTokenIdentifier>( + tokenProto.getIdentifier().toByteArray(), + tokenProto.getPassword().toByteArray(), + new Text(tokenProto.getKind()), new Text(tokenProto.getService())); + addJobToken(jobId, user, jobToken); + } + + private void recordJobShuffleInfo(JobID jobId, String user, + Token<JobTokenIdentifier> jobToken) throws IOException { + if (stateDb != null) { + TokenProto tokenProto = TokenProto.newBuilder() + .setIdentifier(ByteString.copyFrom(jobToken.getIdentifier())) + .setPassword(ByteString.copyFrom(jobToken.getPassword())) + .setKind(jobToken.getKind().toString()) + .setService(jobToken.getService().toString()) + .build(); + JobShuffleInfoProto proto = JobShuffleInfoProto.newBuilder() + .setUser(user).setJobToken(tokenProto).build(); + try { + stateDb.put(bytes(jobId.toString()), proto.toByteArray()); + } catch (DBException e) { + throw new IOException("Error storing " + jobId, e); + } + } + addJobToken(jobId, user, jobToken); + } + + private void removeJobShuffleInfo(JobID jobId) throws IOException { + String jobIdStr = jobId.toString(); + secretManager.removeTokenForJob(jobIdStr); + userRsrc.remove(jobIdStr); + if (stateDb != null) { + try { + stateDb.delete(bytes(jobIdStr)); + } catch (DBException e) { + throw new IOException("Unable to remove " + jobId + + " from state store", e); + } + } + } + + private static class LevelDBLogger implements Logger { + private static final Log LOG = LogFactory.getLog(LevelDBLogger.class); + + @Override + public void log(String message) { + LOG.info(message); + } + } + + class HttpPipelineFactory implements ChannelPipelineFactory { + + final Shuffle SHUFFLE; + private SSLFactory sslFactory; + + public HttpPipelineFactory(Configuration conf) throws Exception { + SHUFFLE = getShuffle(conf); + if (conf.getBoolean(MRConfig.SHUFFLE_SSL_ENABLED_KEY, + MRConfig.SHUFFLE_SSL_ENABLED_DEFAULT)) { + LOG.info("Encrypted shuffle is enabled."); + sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf); + sslFactory.init(); + } + } + + public Shuffle getSHUFFLE() { + return SHUFFLE; + } + + public void destroy() { + if (sslFactory != null) { + sslFactory.destroy(); + } + } + + @Override + public ChannelPipeline getPipeline() throws Exception { + ChannelPipeline pipeline = Channels.pipeline(); 
+ if (sslFactory != null) { + pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine())); + } + pipeline.addLast("decoder", new HttpRequestDecoder()); + pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16)); + pipeline.addLast("encoder", new HttpResponseEncoder()); + pipeline.addLast("chunking", new ChunkedWriteHandler()); + pipeline.addLast("shuffle", SHUFFLE); + return pipeline; + // TODO factor security manager into pipeline + // TODO factor out encode/decode to permit binary shuffle + // TODO factor out decode of index to permit alt. models + } + + } + + class Shuffle extends SimpleChannelUpstreamHandler { + + private static final int MAX_WEIGHT = 10 * 1024 * 1024; + private static final int EXPIRE_AFTER_ACCESS_MINUTES = 5; + private static final int ALLOWED_CONCURRENCY = 16; + private final Configuration conf; + private final IndexCache indexCache; + private final LocalDirAllocator lDirAlloc = + new LocalDirAllocator(YarnConfiguration.NM_LOCAL_DIRS); + private int port; + private final LoadingCache<AttemptPathIdentifier, AttemptPathInfo> pathCache = + CacheBuilder.newBuilder().expireAfterAccess(EXPIRE_AFTER_ACCESS_MINUTES, + TimeUnit.MINUTES).softValues().concurrencyLevel(ALLOWED_CONCURRENCY). + removalListener( + new RemovalListener<AttemptPathIdentifier, AttemptPathInfo>() { + @Override + public void onRemoval(RemovalNotification<AttemptPathIdentifier, + AttemptPathInfo> notification) { + if (LOG.isDebugEnabled()) { + LOG.debug("PathCache Eviction: " + notification.getKey() + + ", Reason=" + notification.getCause()); + } + } + } + ).maximumWeight(MAX_WEIGHT).weigher( + new Weigher<AttemptPathIdentifier, AttemptPathInfo>() { + @Override + public int weigh(AttemptPathIdentifier key, + AttemptPathInfo value) { + return key.jobId.length() + key.user.length() + + key.attemptId.length()+ + value.indexPath.toString().length() + + value.dataPath.toString().length(); + } + } + ).build(new CacheLoader<AttemptPathIdentifier, AttemptPathInfo>() { + @Override + public AttemptPathInfo load(AttemptPathIdentifier key) throws + Exception { + String base = getBaseLocation(key.jobId, key.user); + String attemptBase = base + key.attemptId; + Path indexFileName = lDirAlloc.getLocalPathToRead( + attemptBase + "/" + INDEX_FILE_NAME, conf); + Path mapOutputFileName = lDirAlloc.getLocalPathToRead( + attemptBase + "/" + DATA_FILE_NAME, conf); + + if (LOG.isDebugEnabled()) { + LOG.debug("Loaded : " + key + " via loader"); + } + return new AttemptPathInfo(indexFileName, mapOutputFileName); + } + }); + + public Shuffle(Configuration conf) { + this.conf = conf; + indexCache = new IndexCache(conf); + this.port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT); + } + + public void setPort(int port) { + this.port = port; + } + + private List<String> splitMaps(List<String> mapq) { + if (null == mapq) { + return null; + } + final List<String> ret = new ArrayList<String>(); + for (String s : mapq) { + Collections.addAll(ret, s.split(",")); + } + return ret; + } + + @Override + public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent evt) + throws Exception { + + if ((maxShuffleConnections > 0) && (accepted.size() >= maxShuffleConnections)) { + LOG.info(String.format("Current number of shuffle connections (%d) is " + + "greater than or equal to the max allowed shuffle connections (%d)", + accepted.size(), maxShuffleConnections)); + evt.getChannel().close(); + return; + } + accepted.add(evt.getChannel()); + super.channelOpen(ctx, evt); + } + + @Override + public void 
messageReceived(ChannelHandlerContext ctx, MessageEvent evt) + throws Exception { + HttpRequest request = (HttpRequest) evt.getMessage(); + if (request.getMethod() != GET) { + sendError(ctx, METHOD_NOT_ALLOWED); + return; + } + // Check whether the shuffle version is compatible + if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals( + request.getHeader(ShuffleHeader.HTTP_HEADER_NAME)) + || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals( + request.getHeader(ShuffleHeader.HTTP_HEADER_VERSION))) { + sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST); + } + final Map<String,List<String>> q = + new QueryStringDecoder(request.getUri()).getParameters(); + final List<String> keepAliveList = q.get("keepAlive"); + boolean keepAliveParam = false; + if (keepAliveList != null && keepAliveList.size() == 1) { + keepAliveParam = Boolean.valueOf(keepAliveList.get(0)); + if (LOG.isDebugEnabled()) { + LOG.debug("KeepAliveParam : " + keepAliveList + + " : " + keepAliveParam); + } + } + final List<String> mapIds = splitMaps(q.get("map")); + final List<String> reduceQ = q.get("reduce"); + final List<String> jobQ = q.get("job"); + if (LOG.isDebugEnabled()) { + LOG.debug("RECV: " + request.getUri() + + "\n mapId: " + mapIds + + "\n reduceId: " + reduceQ + + "\n jobId: " + jobQ + + "\n keepAlive: " + keepAliveParam); + } + + if (mapIds == null || reduceQ == null || jobQ == null) { + sendError(ctx, "Required param job, map and reduce", BAD_REQUEST); + return; + } + if (reduceQ.size() != 1 || jobQ.size() != 1) { + sendError(ctx, "Too many job/reduce parameters", BAD_REQUEST); + return; + } + + // this audit log is disabled by default, + // to turn it on please enable this audit log + // on log4j.properties by uncommenting the setting + if (AUDITLOG.isDebugEnabled()) { + AUDITLOG.debug("shuffle for " + jobQ.get(0) + + " reducer " + reduceQ.get(0)); + } + int reduceId; + String jobId; + try { + reduceId = Integer.parseInt(reduceQ.get(0)); + jobId = jobQ.get(0); + } catch (NumberFormatException e) { + sendError(ctx, "Bad reduce parameter", BAD_REQUEST); + return; + } catch (IllegalArgumentException e) { + sendError(ctx, "Bad job parameter", BAD_REQUEST); + return; + } + final String reqUri = request.getUri(); + if (null == reqUri) { + // TODO? add upstream? 
+ sendError(ctx, FORBIDDEN); + return; + } + HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK); + try { + verifyRequest(jobId, ctx, request, response, + new URL("http", "", this.port, reqUri)); + } catch (IOException e) { + LOG.warn("Shuffle failure ", e); + sendError(ctx, e.getMessage(), UNAUTHORIZED); + return; + } + + Map<String, MapOutputInfo> mapOutputInfoMap = + new HashMap<String, MapOutputInfo>(); + Channel ch = evt.getChannel(); + String user = userRsrc.get(jobId); + + try { + populateHeaders(mapIds, jobId, user, reduceId, request, + response, keepAliveParam, mapOutputInfoMap); + } catch(IOException e) { + ch.write(response); + LOG.error("Shuffle error in populating headers :", e); + String errorMessage = getErrorMessage(e); + sendError(ctx,errorMessage , INTERNAL_SERVER_ERROR); + return; + } + ch.write(response); + //Initialize one ReduceContext object per messageReceived call + ReduceContext reduceContext = new ReduceContext(mapIds, reduceId, ctx, + user, mapOutputInfoMap, jobId); + for (int i = 0; i < Math.min(maxSessionOpenFiles, mapIds.size()); i++) { + ChannelFuture nextMap = sendMap(reduceContext); + if(nextMap == null) { + return; + } + } + } + + /** + * Calls sendMapOutput for the mapId pointed by ReduceContext.mapsToSend + * and increments it. This method is first called by messageReceived() + * maxSessionOpenFiles times and then on the completion of every + * sendMapOutput operation. This limits the number of open files on a node, + * which can get really large(exhausting file descriptors on the NM) if all + * sendMapOutputs are called in one go, as was done previous to this change. + * @param reduceContext used to call sendMapOutput with correct params. + * @return the ChannelFuture of the sendMapOutput, can be null. 
+ */ + public ChannelFuture sendMap(ReduceContext reduceContext) + throws Exception { + + ChannelFuture nextMap = null; + if (reduceContext.getMapsToSend().get() < + reduceContext.getMapIds().size()) { + int nextIndex = reduceContext.getMapsToSend().getAndIncrement(); + String mapId = reduceContext.getMapIds().get(nextIndex); + + try { + MapOutputInfo info = reduceContext.getInfoMap().get(mapId); + if (info == null) { + info = getMapOutputInfo(mapId, reduceContext.getReduceId(), + reduceContext.getJobId(), reduceContext.getUser()); + } + nextMap = sendMapOutput( + reduceContext.getCtx(), + reduceContext.getCtx().getChannel(), + reduceContext.getUser(), mapId, + reduceContext.getReduceId(), info); + if (null == nextMap) { + sendError(reduceContext.getCtx(), NOT_FOUND); + return null; + } + nextMap.addListener(new ReduceMapFileCount(reduceContext)); + } catch (IOException e) { + LOG.error("Shuffle error :", e); + String errorMessage = getErrorMessage(e); + sendError(reduceContext.getCtx(), errorMessage, + INTERNAL_SERVER_ERROR); + return null; + } + } + return nextMap; + } + + private String getErrorMessage(Throwable t) { + StringBuffer sb = new StringBuffer(t.getMessage()); + while (t.getCause() != null) { + sb.append(t.getCause().getMessage()); + t = t.getCause(); + } + return sb.toString(); + } + + private String getBaseLocation(String jobId, String user) { + final JobID jobID = JobID.forName(jobId); + final ApplicationId appID = + ApplicationId.newInstance(Long.parseLong(jobID.getJtIdentifier()), + jobID.getId()); + final String baseStr = + USERCACHE + "/" + user + "/" + + APPCACHE + "/" + + appID.toString() + "/output" + "/"; + return baseStr; + } + + protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, + String jobId, String user) throws IOException { + AttemptPathInfo pathInfo; + try { + AttemptPathIdentifier identifier = new AttemptPathIdentifier( + jobId, user, mapId); + pathInfo = pathCache.get(identifier); + if (LOG.isDebugEnabled()) { + LOG.debug("Retrieved pathInfo for " + identifier + + " check for corresponding loaded messages to determine whether" + + " it was loaded or cached"); + } + } catch (ExecutionException e) { + if (e.getCause() instanceof IOException) { + throw (IOException) e.getCause(); + } else { + throw new RuntimeException(e.getCause()); + } + } + + TezIndexRecord info = + indexCache.getIndexInformation(mapId, reduce, pathInfo.indexPath, user); + + if (LOG.isDebugEnabled()) { + LOG.debug("getMapOutputInfo: jobId=" + jobId + ", mapId=" + mapId + + ",dataFile=" + pathInfo.dataPath + ", indexFile=" + + pathInfo.indexPath); + } + + MapOutputInfo outputInfo = new MapOutputInfo(pathInfo.dataPath, info); + return outputInfo; + } + + protected void populateHeaders(List<String> mapIds, String jobId, + String user, int reduce, HttpRequest request, HttpResponse response, + boolean keepAliveParam, Map<String, MapOutputInfo> mapOutputInfoMap) + throws IOException { + + long contentLength = 0; + for (String mapId : mapIds) { + MapOutputInfo outputInfo = getMapOutputInfo(mapId, reduce, jobId, user); + if (mapOutputInfoMap.size() < mapOutputMetaInfoCacheSize) { + mapOutputInfoMap.put(mapId, outputInfo); + } + + ShuffleHeader header = + new ShuffleHeader(mapId, outputInfo.indexRecord.getPartLength(), + outputInfo.indexRecord.getRawLength(), reduce); + DataOutputBuffer dob = new DataOutputBuffer(); + header.write(dob); + + contentLength += outputInfo.indexRecord.getPartLength(); + contentLength += dob.getLength(); + } + + // Now set the response headers. 
+ setResponseHeaders(response, keepAliveParam, contentLength); + } + + protected void setResponseHeaders(HttpResponse response, + boolean keepAliveParam, long contentLength) { + if (!connectionKeepAliveEnabled && !keepAliveParam) { + if (LOG.isDebugEnabled()) { + LOG.debug("Setting connection close header..."); + } + response.setHeader(HttpHeaders.CONNECTION, CONNECTION_CLOSE); + } else { + response.setHeader(HttpHeaders.CONTENT_LENGTH, + String.valueOf(contentLength)); + response.setHeader(HttpHeaders.CONNECTION, HttpHeaders.KEEP_ALIVE); + response.setHeader(HttpHeaders.KEEP_ALIVE, "timeout=" + + connectionKeepAliveTimeOut); + LOG.info("Content Length in shuffle : " + contentLength); + } + } + + class MapOutputInfo { + final Path mapOutputFileName; + final TezIndexRecord indexRecord; + + MapOutputInfo(Path mapOutputFileName, TezIndexRecord indexRecord) { + this.mapOutputFileName = mapOutputFileName; + this.indexRecord = indexRecord; + } + } + + protected void verifyRequest(String appid, ChannelHandlerContext ctx, + HttpRequest request, HttpResponse response, URL requestUri) + throws IOException { + SecretKey tokenSecret = secretManager.retrieveTokenSecret(appid); + if (null == tokenSecret) { + LOG.info("Request for unknown token " + appid); + throw new IOException("could not find jobid"); + } + // string to encrypt + String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri); + // hash from the fetcher + String urlHashStr = + request.getHeader(SecureShuffleUtils.HTTP_HEADER_URL_HASH); + if (urlHashStr == null) { + LOG.info("Missing header hash for " + appid); + throw new IOException("fetcher cannot be authenticated"); + } + if (LOG.isDebugEnabled()) { + int len = urlHashStr.length(); + LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." + + urlHashStr.substring(len-len/2, len-1)); + } + // verify - throws exception + SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret); + // verification passed - encode the reply + String reply = + SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8), + tokenSecret); + response.setHeader(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply); + // Put shuffle version into http header + response.setHeader(ShuffleHeader.HTTP_HEADER_NAME, + ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); + response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION, + ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); + if (LOG.isDebugEnabled()) { + int len = reply.length(); + LOG.debug("Fetcher request verfied. 
enc_str=" + enc_str + ";reply=" + + reply.substring(len-len/2, len-1)); + } + } + + protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch, + String user, String mapId, int reduce, MapOutputInfo mapOutputInfo) + throws IOException { + final TezIndexRecord info = mapOutputInfo.indexRecord; + final ShuffleHeader header = + new ShuffleHeader(mapId, info.getPartLength(), info.getRawLength(), reduce); + final DataOutputBuffer dob = new DataOutputBuffer(); + header.write(dob); + ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + final File spillfile = + new File(mapOutputInfo.mapOutputFileName.toString()); + RandomAccessFile spill; + try { + spill = SecureIOUtils.openForRandomRead(spillfile, "r", user, null); + } catch (FileNotFoundException e) { + LOG.info(spillfile + " not found"); + return null; + } + ChannelFuture writeFuture; + if (ch.getPipeline().get(SslHandler.class) == null) { + final FadvisedFileRegion partition = new FadvisedFileRegion(spill, + info.getStartOffset(), info.getPartLength(), manageOsCache, readaheadLength, + readaheadPool, spillfile.getAbsolutePath(), + shuffleBufferSize, shuffleTransferToAllowed); + writeFuture = ch.write(partition); + writeFuture.addListener(new ChannelFutureListener() { + // TODO error handling; distinguish IO/connection failures, + // attribute to appropriate spill output + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + partition.transferSuccessful(); + } + partition.releaseExternalResources(); + } + }); + } else { + // HTTPS cannot be done with zero copy. + final FadvisedChunkedFile chunk = new FadvisedChunkedFile(spill, + info.getStartOffset(), info.getPartLength(), sslFileBufferSize, + manageOsCache, readaheadLength, readaheadPool, + spillfile.getAbsolutePath()); + writeFuture = ch.write(chunk); + } + metrics.shuffleConnections.incr(); + metrics.shuffleOutputBytes.incr(info.getPartLength()); // optimistic + return writeFuture; + } + + protected void sendError(ChannelHandlerContext ctx, + HttpResponseStatus status) { + sendError(ctx, "", status); + } + + protected void sendError(ChannelHandlerContext ctx, String message, + HttpResponseStatus status) { + HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status); + response.setHeader(CONTENT_TYPE, "text/plain; charset=UTF-8"); + // Put shuffle version into http header + response.setHeader(ShuffleHeader.HTTP_HEADER_NAME, + ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); + response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION, + ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); + response.setContent( + ChannelBuffers.copiedBuffer(message, CharsetUtil.UTF_8)); + + // Close the connection as soon as the error message is sent. 
+ ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) + throws Exception { + Channel ch = e.getChannel(); + Throwable cause = e.getCause(); + if (cause instanceof TooLongFrameException) { + sendError(ctx, BAD_REQUEST); + return; + } else if (cause instanceof IOException) { + if (cause instanceof ClosedChannelException) { + LOG.debug("Ignoring closed channel error", cause); + return; + } + String message = String.valueOf(cause.getMessage()); + if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) { + LOG.debug("Ignoring client socket close", cause); + return; + } + } + + LOG.error("Shuffle error: ", cause); + if (ch.isConnected()) { + LOG.error("Shuffle error " + e); + sendError(ctx, INTERNAL_SERVER_ERROR); + } + } + } + + static class AttemptPathInfo { + // TODO Change this over to just store local dir indices, instead of the + // entire path. Far more efficient. + private final Path indexPath; + private final Path dataPath; + + public AttemptPathInfo(Path indexPath, Path dataPath) { + this.indexPath = indexPath; + this.dataPath = dataPath; + } + } + + static class AttemptPathIdentifier { + private final String jobId; + private final String user; + private final String attemptId; + + public AttemptPathIdentifier(String jobId, String user, String attemptId) { + this.jobId = jobId; + this.user = user; + this.attemptId = attemptId; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + AttemptPathIdentifier that = (AttemptPathIdentifier) o; + + if (!attemptId.equals(that.attemptId)) { + return false; + } + if (!jobId.equals(that.jobId)) { + return false; + } + + return true; + } + + @Override + public int hashCode() { + int result = jobId.hashCode(); + result = 31 * result + attemptId.hashCode(); + return result; + } + + @Override + public String toString() { + return "AttemptPathIdentifier{" + + "attemptId='" + attemptId + '\'' + + ", jobId='" + jobId + '\'' + + '}'; + } + } +}
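----------------------------------------------------------------------
Note on wiring up the service (illustrative, not part of this commit): the
patch keeps the MapReduce wire protocol and service id ("mapreduce_shuffle"),
so an application master hands the job token to the handler as aux-service
data and reads the shuffle port back from the NodeManager's service metadata.
The sketch below uses only the public helpers added in ShuffleHandler plus
standard YARN records; the class and method names of the sketch itself are
hypothetical.

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.Map;

import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.protocolrecords.StartContainersResponse;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.tez.auxservices.ShuffleHandler;
import org.apache.tez.common.security.JobTokenIdentifier;

public final class ShuffleServiceClientSketch {

  /** Attach the serialized job token as service data for the shuffle handler. */
  public static void attachJobToken(ContainerLaunchContext clc,
      Token<JobTokenIdentifier> jobToken) throws IOException {
    ByteBuffer serviceData = ShuffleHandler.serializeServiceData(jobToken);
    clc.setServiceData(Collections.singletonMap(
        ShuffleHandler.MAPREDUCE_SHUFFLE_SERVICEID, serviceData));
  }

  /** Recover the port the shuffle handler is listening on from the NM response. */
  public static int shufflePort(StartContainersResponse response) throws IOException {
    Map<String, ByteBuffer> meta = response.getAllServicesMetaData();
    return ShuffleHandler.deserializeMetaData(
        meta.get(ShuffleHandler.MAPREDUCE_SHUFFLE_SERVICEID));
  }
}

On the NodeManager side the service would be registered the usual way, via
yarn.nodemanager.aux-services and the matching
yarn.nodemanager.aux-services.<service>.class property pointing at
org.apache.tez.auxservices.ShuffleHandler; that yarn-site.xml change is not
part of this patch.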
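----------------------------------------------------------------------
For reference, the request shape that Shuffle.messageReceived() above accepts
(illustrative helper, not code from the commit): the handler only inspects the
query parameters job, map (comma-separated attempt ids), reduce and keepAlive,
and rejects requests whose shuffle version headers do not match. A real
fetcher additionally signs the URL with the job token and sends the hash in
SecureShuffleUtils.HTTP_HEADER_URL_HASH; that step is omitted here, so on a
secured handler this request would come back UNAUTHORIZED.

import java.net.HttpURLConnection;
import java.net.URL;

import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ShuffleHeader;

public final class FetchRequestSketch {
  public static HttpURLConnection openFetch(String host, int port, String jobId,
      String mapAttemptIds, int reduceId, boolean keepAlive) throws Exception {
    // /mapOutput is the path conventionally used by the MR/Tez fetchers;
    // the handler itself only looks at the query string.
    URL url = new URL("http://" + host + ":" + port
        + "/mapOutput?job=" + jobId
        + "&reduce=" + reduceId
        + "&map=" + mapAttemptIds
        + (keepAlive ? "&keepAlive=true" : ""));
    HttpURLConnection conn = (HttpURLConnection) url.openConnection();
    // Without these two headers the handler answers with BAD_REQUEST.
    conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
        ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
    conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
        ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
    return conn;
  }
}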
