This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new c6d186cb [#854][FOLLOWUP] feat(tez): Add RssShuffleScheduler to run
and manager shuffle work (#948)
c6d186cb is described below
commit c6d186cb07732218642a58a10080ace6a1509898
Author: Qing <[email protected]>
AuthorDate: Thu Jun 15 18:50:33 2023 +0800
[#854][FOLLOWUP] feat(tez): Add RssShuffleScheduler to run and manage
shuffle work (#948)
### What changes were proposed in this pull request?
Add RssShuffleScheduler to run and manage shuffle work
### Why are the changes needed?
Fix: https://github.com/apache/incubator-uniffle/issues/854
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
unit test
---
.../java/org/apache/tez/common/RssTezConfig.java | 12 +
.../orderedgrouped/RssShuffleScheduler.java | 1690 ++++++++++++++++++++
.../orderedgrouped/RssShuffleSchedulerTest.java | 912 +++++++++++
3 files changed, 2614 insertions(+)
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
index f544773f..25a382c7 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezConfig.java
@@ -102,6 +102,8 @@ public class RssTezConfig {
public static final String DEBUG_HIVE_TEZ_LOG_LEVEL = "debug";
public static final String RSS_STORAGE_TYPE = TEZ_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_STORAGE_TYPE;
+ public static final String RSS_STORAGE_TYPE_DEFAULT_VALUE =
"MEMORY_LOCALFILE";
+
public static final String RSS_DYNAMIC_CLIENT_CONF_ENABLED =
TEZ_RSS_CONFIG_PREFIX +
RssClientConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED;
@@ -141,6 +143,16 @@ public class RssTezConfig {
public static final int
RSS_ESTIMATE_TASK_CONCURRENCY_PER_SERVER_DEFAULT_VALUE =
RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_PER_SERVER_DEFAULT_VALUE;
+ public static final String RSS_CLIENT_READ_BUFFER_SIZE =
+ TEZ_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_READ_BUFFER_SIZE;
+ public static final String RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE =
+ RssClientConfig.RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE;
+
+ public static final String RSS_PARTITION_NUM_PER_RANGE =
+ TEZ_RSS_CONFIG_PREFIX + RssClientConfig.RSS_PARTITION_NUM_PER_RANGE;
+ public static final int RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE =
+ RssClientConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE;
+
public static final String RSS_CONF_FILE = "rss_conf.xml";
public static final String RSS_REMOTE_STORAGE_PATH =
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
new file mode 100644
index 00000000..c1d816f8
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
@@ -0,0 +1,1690 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.DelayQueue;
+import java.util.concurrent.Delayed;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.crypto.SecretKey;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.LinkedListMultimap;
+import com.google.common.collect.ListMultimap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.tez.common.CallableWithNdc;
+import org.apache.tez.common.IdUtils;
+import org.apache.tez.common.InputContextUtils;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezIdHelper;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezException;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.http.HttpConnectionParams;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.events.InputReadErrorEvent;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.CompositeInputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+import org.apache.tez.runtime.library.common.shuffle.HostPort;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import
org.apache.tez.runtime.library.common.shuffle.ShuffleUtils.FetchStatsLogger;
+import
org.apache.tez.runtime.library.common.shuffle.orderedgrouped.MapHost.HostPortPartition;
+import
org.apache.tez.runtime.library.common.shuffle.orderedgrouped.MapOutput.Type;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.UnitConverter;
+
+class RssShuffleScheduler extends ShuffleScheduler {
+
+ public static class PathPartition {
+
+ final String path;
+ final int partition;
+
+ PathPartition(String path, int partition) {
+ this.path = path;
+ this.partition = partition;
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + ((path == null) ? 0 : path.hashCode());
+ result = prime * result + partition;
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ PathPartition other = (PathPartition) obj;
+ if (path == null) {
+ if (other.path != null) {
+ return false;
+ }
+ } else if (!path.equals(other.path)) {
+ return false;
+ }
+ if (partition != other.partition) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public String toString() {
+ return "PathPartition [path=" + path + ", partition=" + partition + "]";
+ }
+ }
+
  // Categories of shuffle fetch failures; each maps to a counter in the
  // SHUFFLE_ERR_GRP_NAME counter group (see the fetcher counters in the constructor).
  @VisibleForTesting
  enum ShuffleErrors {
    IO_ERROR,
    WRONG_LENGTH,
    BAD_ID,
    WRONG_MAP,
    CONNECTION,
    WRONG_REDUCE
  }
+
  // Counter-group name under which the ShuffleErrors counters are reported.
  @VisibleForTesting
  static final String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";

  private final AtomicLong shuffleStart = new AtomicLong(0);

  private static final Logger LOG = LoggerFactory.getLogger(RssShuffleScheduler.class);
  // Separate "<LOG>.fetch" logger used by fetchStatsLogger for per-fetch stats.
  private static final Logger LOG_FETCH = LoggerFactory.getLogger(LOG.getName() + ".fetch");
  private static final FetchStatsLogger fetchStatsLogger = new FetchStatsLogger(LOG_FETCH, LOG);

  // Penalty parameters; presumably consumed by the Referee/penalties machinery — TODO confirm.
  static final long INITIAL_PENALTY = 2000L; // 2 seconds
  private static final float PENALTY_GROWTH_RATE = 1.3f;

  // One bit per source input, sized to the number of inputs in the constructor.
  private final BitSet finishedMaps;
  private final int numInputs;
  private int numFetchedSpills;
  @VisibleForTesting
  final Map<HostPortPartition, MapHost> mapLocations = new HashMap<>();
  //TODO Clean this and other maps at some point
  @VisibleForTesting
  final ConcurrentMap<PathPartition, InputAttemptIdentifier> pathToIdentifierMap
      = new ConcurrentHashMap<>();

  // To track shuffleInfo events when finalMerge is disabled in source or pipelined shuffle is
  // enabled in source.
  @VisibleForTesting
  final Map<Integer, ShuffleEventInfo> pipelinedShuffleInfoEventsMap;

  @VisibleForTesting
  final Set<MapHost> pendingHosts = new HashSet<>();
  private final Set<InputAttemptIdentifier> obsoleteInputs = new HashSet<>();

  private final AtomicBoolean isShutdown = new AtomicBoolean(false);
  private final Random random = new Random(System.currentTimeMillis());
  // Delay queue of penalized hosts; the Referee thread is interrupted/joined in close().
  private final DelayQueue<Penalty> penalties = new DelayQueue<>();
  private final Referee referee;

  // Failure bookkeeping; entries are cleared on successful copies (see copySucceeded).
  @VisibleForTesting
  final Map<InputAttemptIdentifier, IntWritable> failureCounts = new HashMap<>();
  final Set<HostPort> uniqueHosts = Sets.newHashSet();
  private final Map<HostPort,IntWritable> hostFailures = new HashMap<>();
  private final InputContext inputContext;

  // Task counters updated as inputs are shuffled.
  private final TezCounter shuffledInputsCounter;
  private final TezCounter skippedInputCounter;
  private final TezCounter reduceShuffleBytes;
  private final TezCounter reduceBytesDecompressed;
  @VisibleForTesting
  final TezCounter failedShuffleCounter;
  private final TezCounter bytesShuffledToDisk;
  private final TezCounter bytesShuffledToDiskDirect;
  private final TezCounter bytesShuffledToMem;
  private final TezCounter firstEventReceived;
  private final TezCounter lastEventReceived;

  private final String srcNameTrimmed;
  // Inputs still to be fetched; decremented in copySucceeded, 0 means all inputs done.
  @VisibleForTesting
  final AtomicInteger remainingMaps;
  private final long startTime;
  @VisibleForTesting
  long lastProgressTime;
  @VisibleForTesting
  long failedShufflesSinceLastCompletion;

  // Fetcher pool size, capped by the number of inputs (see constructor).
  private final int numFetchers;
  private final Set<RssTezShuffleDataFetcher> rssRunningFetchers =
      Collections.newSetFromMap(new ConcurrentHashMap<>());

  private final ListeningExecutorService fetcherExecutor;

  private final HttpConnectionParams httpConnectionParams;
  private final FetchedInputAllocatorOrderedGrouped allocator;
  private final ExceptionReporter exceptionReporter;
  private final MergeManager mergeManager;
  private final JobTokenSecretManager jobTokenSecretManager;
  private final boolean ifileReadAhead;
  private final int ifileReadAheadLength;
  private final CompressionCodec codec;
  private final Configuration conf;
  private final boolean localDiskFetchEnabled;
  private final String localHostname;
  private final int shufflePort;
  private final String applicationId;
  private final int dagId;
  private final boolean asyncHttp;
  private final boolean sslShuffle;

  // Per-error-type counters, one per ShuffleErrors constant.
  private final TezCounter ioErrsCounter;
  private final TezCounter wrongLengthErrsCounter;
  private final TezCounter badIdErrsCounter;
  private final TezCounter wrongMapErrsCounter;
  private final TezCounter connectionErrsCounter;
  private final TezCounter wrongReduceErrsCounter;

  // Limits and thresholds resolved from TezRuntimeConfiguration in the constructor.
  private final int maxTaskOutputAtOnce;
  private final int maxFetchFailuresBeforeReporting;
  private final boolean reportReadErrorImmediately;
  private final int maxFailedUniqueFetches;
  private final int abortFailureLimit;

  private final int minFailurePerHost;
  private final float hostFailureFraction;
  private final float maxStallTimeFraction;
  private final float minReqProgressFraction;
  private final float maxAllowedFailedFetchFraction;
  private final boolean checkFailedFetchSinceLastCompletion;
  private final boolean verifyDiskChecksum;
  private final boolean compositeFetch;

  // Thread that invoked start(); interrupted by close() when close runs on another thread.
  private volatile Thread shuffleSchedulerThread = null;

  private long totalBytesShuffledTillNow = 0;
  private final DecimalFormat mbpsFormat = new DecimalFormat("0.00");


  // For Rss: partition id -> shuffle servers, populated in start() via UmbilicalUtils.
  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
  private final Map<Integer, MapHost> runningRssPartitionMap = new HashMap<>();

  private final Set<Integer> successRssPartitionSet = Sets.newConcurrentHashSet();
  private final Set<Integer> allRssPartition = Sets.newConcurrentHashSet();

  private final Map<Integer, Set<InputAttemptIdentifier>>
      partitionIdToSuccessMapTaskAttempts = new HashMap<>();
  private final String storageType;


  // RSS client read settings resolved from RssTezConfig in the constructor.
  private final int readBufferSize;
  private final int partitionNumPerRange;
  private String basePath;
  private int indexReadLimit;
+
+ RssShuffleScheduler(InputContext inputContext,
+ Configuration conf,
+ int numberOfInputs,
+ ExceptionReporter exceptionReporter,
+ MergeManager mergeManager,
+ FetchedInputAllocatorOrderedGrouped allocator,
+ long startTime,
+ CompressionCodec codec,
+ boolean ifileReadAhead,
+ int ifileReadAheadLength,
+ String srcNameTrimmed) throws IOException {
+ super(inputContext, conf, numberOfInputs, exceptionReporter, mergeManager,
allocator, startTime, codec,
+ ifileReadAhead, ifileReadAheadLength, srcNameTrimmed);
+ this.inputContext = inputContext;
+ this.conf = conf;
+ this.exceptionReporter = exceptionReporter;
+ this.allocator = allocator;
+ this.mergeManager = mergeManager;
+ this.numInputs = numberOfInputs;
+ int abortFailureLimitConf = conf.getInt(TezRuntimeConfiguration
+ .TEZ_RUNTIME_SHUFFLE_SOURCE_ATTEMPT_ABORT_LIMIT,
TezRuntimeConfiguration
+ .TEZ_RUNTIME_SHUFFLE_SOURCE_ATTEMPT_ABORT_LIMIT_DEFAULT);
+ if (abortFailureLimitConf <= -1) {
+ abortFailureLimit = Math.max(15, numberOfInputs / 10);
+ } else {
+ //No upper cap, as user is setting this intentionally
+ abortFailureLimit = abortFailureLimitConf;
+ }
+ remainingMaps = new AtomicInteger(numberOfInputs); // total up-stream task
+
+ finishedMaps = new BitSet(numberOfInputs);
+ this.ifileReadAhead = ifileReadAhead;
+ this.ifileReadAheadLength = ifileReadAheadLength;
+ this.srcNameTrimmed = srcNameTrimmed;
+ this.codec = codec;
+ int configuredNumFetchers =
+ conf.getInt(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES_DEFAULT);
+ numFetchers = Math.min(configuredNumFetchers, numInputs);
+
+ localDiskFetchEnabled = conf.getBoolean(
+ TezRuntimeConfiguration.TEZ_RUNTIME_OPTIMIZE_LOCAL_FETCH,
+ TezRuntimeConfiguration.TEZ_RUNTIME_OPTIMIZE_LOCAL_FETCH_DEFAULT);
+
+ this.minFailurePerHost = conf.getInt(
+ TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_FAILURES_PER_HOST,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_FAILURES_PER_HOST_DEFAULT);
+ Preconditions.checkArgument(minFailurePerHost >= 0,
+ TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_FAILURES_PER_HOST
+ + "=" + minFailurePerHost + " should not be negative");
+
+ this.hostFailureFraction = conf.getFloat(TezRuntimeConfiguration
+
.TEZ_RUNTIME_SHUFFLE_ACCEPTABLE_HOST_FETCH_FAILURE_FRACTION,
+ TezRuntimeConfiguration
+
.TEZ_RUNTIME_SHUFFLE_ACCEPTABLE_HOST_FETCH_FAILURE_FRACTION_DEFAULT);
+
+ this.maxStallTimeFraction = conf.getFloat(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_STALL_TIME_FRACTION,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_STALL_TIME_FRACTION_DEFAULT);
+ Preconditions.checkArgument(maxStallTimeFraction >= 0,
+ TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_STALL_TIME_FRACTION
+ + "=" + maxStallTimeFraction + " should not be negative");
+
+ this.minReqProgressFraction = conf.getFloat(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_REQUIRED_PROGRESS_FRACTION,
+ TezRuntimeConfiguration
+
.TEZ_RUNTIME_SHUFFLE_MIN_REQUIRED_PROGRESS_FRACTION_DEFAULT);
+ Preconditions.checkArgument(minReqProgressFraction >= 0,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_REQUIRED_PROGRESS_FRACTION
+ + "=" + minReqProgressFraction + " should not be
negative");
+
+ this.maxAllowedFailedFetchFraction = conf.getFloat(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_ALLOWED_FAILED_FETCH_ATTEMPT_FRACTION,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_ALLOWED_FAILED_FETCH_ATTEMPT_FRACTION_DEFAULT);
+ Preconditions.checkArgument(maxAllowedFailedFetchFraction >= 0,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MAX_ALLOWED_FAILED_FETCH_ATTEMPT_FRACTION
+ + "=" + maxAllowedFailedFetchFraction + " should not be
negative");
+
+ this.checkFailedFetchSinceLastCompletion = conf.getBoolean(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FAILED_CHECK_SINCE_LAST_COMPLETION,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FAILED_CHECK_SINCE_LAST_COMPLETION_DEFAULT);
+
+ this.applicationId = IdUtils.getApplicationAttemptId().toString();
+ this.dagId = inputContext.getDagIdentifier();
+ this.localHostname = inputContext.getExecutionContext().getHostName();
+ String auxiliaryService =
conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
+ TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);
+ final ByteBuffer shuffleMetadata =
+ inputContext.getServiceProviderMetaData(auxiliaryService);
+ this.shufflePort =
ShuffleUtils.deserializeShuffleProviderMetaData(shuffleMetadata);
+
+ this.referee = new Referee();
+ // Counters used by the ShuffleScheduler
+ this.shuffledInputsCounter =
inputContext.getCounters().findCounter(TaskCounter.NUM_SHUFFLED_INPUTS);
+ this.reduceShuffleBytes =
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES);
+ this.reduceBytesDecompressed =
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_DECOMPRESSED);
+ this.failedShuffleCounter =
inputContext.getCounters().findCounter(TaskCounter.NUM_FAILED_SHUFFLE_INPUTS);
+ this.bytesShuffledToDisk =
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_TO_DISK);
+ this.bytesShuffledToDiskDirect =
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_DISK_DIRECT);
+ this.bytesShuffledToMem =
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_TO_MEM);
+
+ // Counters used by Fetchers
+ ioErrsCounter =
inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
ShuffleErrors.IO_ERROR.toString());
+ wrongLengthErrsCounter =
inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+ ShuffleErrors.WRONG_LENGTH.toString());
+ badIdErrsCounter =
inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
ShuffleErrors.BAD_ID.toString());
+ wrongMapErrsCounter =
inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+ ShuffleErrors.WRONG_MAP.toString());
+ connectionErrsCounter =
inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+ ShuffleErrors.CONNECTION.toString());
+ wrongReduceErrsCounter =
inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+ ShuffleErrors.WRONG_REDUCE.toString());
+
+ this.startTime = startTime;
+ this.lastProgressTime = startTime;
+
+ this.sslShuffle =
conf.getBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL,
+ TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL_DEFAULT);
+ this.asyncHttp =
conf.getBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_USE_ASYNC_HTTP,
false);
+ this.httpConnectionParams = ShuffleUtils.getHttpConnectionParams(conf);
+ SecretKey jobTokenSecret = ShuffleUtils
+ .getJobTokenSecretFromTokenBytes(inputContext
+ .getServiceConsumerMetaData(auxiliaryService));
+ this.jobTokenSecretManager = new JobTokenSecretManager(jobTokenSecret);
+
+ final ExecutorService fetcherRawExecutor;
+ if
(conf.getBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCHER_USE_SHARED_POOL,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCHER_USE_SHARED_POOL_DEFAULT)) {
+ fetcherRawExecutor =
inputContext.createTezFrameworkExecutorService(numFetchers,
+ "Fetcher_O {" + srcNameTrimmed + "} #%d");
+ } else {
+ fetcherRawExecutor = Executors.newFixedThreadPool(numFetchers, new
ThreadFactoryBuilder()
+ .setDaemon(true).setNameFormat("Fetcher_O {" + srcNameTrimmed +
"} #%d").build());
+ }
+ this.fetcherExecutor =
MoreExecutors.listeningDecorator(fetcherRawExecutor);
+
+ this.maxFailedUniqueFetches = Math.min(numberOfInputs, 5);
+ referee.start();
+ this.maxFetchFailuresBeforeReporting =
+ conf.getInt(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES_LIMIT,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES_LIMIT_DEFAULT);
+ this.reportReadErrorImmediately =
+ conf.getBoolean(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR_DEFAULT);
+ this.verifyDiskChecksum = conf.getBoolean(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_VERIFY_DISK_CHECKSUM,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_VERIFY_DISK_CHECKSUM_DEFAULT);
+
+ /**
+ * Setting to very high val can lead to Http 400 error. Cap it to 75;
every attempt id would
+ * be approximately 48 bytes; 48 * 75 = 3600 which should give some room
for other info in URL.
+ */
+ this.maxTaskOutputAtOnce = Math.max(1, Math.min(75, conf.getInt(
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_MAX_TASK_OUTPUT_AT_ONCE,
+
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_MAX_TASK_OUTPUT_AT_ONCE_DEFAULT)));
+
+ this.skippedInputCounter =
inputContext.getCounters().findCounter(TaskCounter.NUM_SKIPPED_INPUTS);
+ this.firstEventReceived =
inputContext.getCounters().findCounter(TaskCounter.FIRST_EVENT_RECEIVED);
+ this.lastEventReceived =
inputContext.getCounters().findCounter(TaskCounter.LAST_EVENT_RECEIVED);
+ this.compositeFetch = ShuffleUtils.isTezShuffleHandler(conf);
+
+ pipelinedShuffleInfoEventsMap = Maps.newConcurrentMap();
+
+ this.storageType = conf.get(RssTezConfig.RSS_STORAGE_TYPE,
RssTezConfig.RSS_STORAGE_TYPE_DEFAULT_VALUE);
+ String readBufferSize = conf.get(RssTezConfig.RSS_CLIENT_READ_BUFFER_SIZE,
+ RssTezConfig.RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE);
+ this.readBufferSize = (int)
UnitConverter.byteStringAsBytes(readBufferSize);
+ this.partitionNumPerRange =
conf.getInt(RssTezConfig.RSS_PARTITION_NUM_PER_RANGE,
+ RssTezConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE);
+
+ LOG.info("RSSShuffleScheduler running for sourceVertex: "
+ + inputContext.getSourceVertexName() + " with configuration: "
+ + "maxFetchFailuresBeforeReporting=" +
maxFetchFailuresBeforeReporting
+ + ", reportReadErrorImmediately=" + reportReadErrorImmediately
+ + ", maxFailedUniqueFetches=" + maxFailedUniqueFetches
+ + ", abortFailureLimit=" + abortFailureLimit
+ + ", maxTaskOutputAtOnce=" + maxTaskOutputAtOnce
+ + ", numFetchers=" + numFetchers
+ + ", hostFailureFraction=" + hostFailureFraction
+ + ", minFailurePerHost=" + minFailurePerHost
+ + ", maxAllowedFailedFetchFraction=" +
maxAllowedFailedFetchFraction
+ + ", maxStallTimeFraction=" + maxStallTimeFraction
+ + ", minReqProgressFraction=" + minReqProgressFraction
+ + ", checkFailedFetchSinceLastCompletion=" +
checkFailedFetchSinceLastCompletion
+ + ", storyType=" + storageType + ", readBufferSize=" +
this.readBufferSize
+ + ", partitionNumPerRange=" + partitionNumPerRange);
+ }
+
  /**
   * Requests the partition-to-shuffle-server assignment for this task attempt and then runs
   * the scheduling loop on the calling thread (via RssShuffleSchedulerCallable.call()),
   * recording that thread so close() can interrupt it from another thread.
   */
  @Override
  public void start() throws Exception {
    int shuffleId = InputContextUtils.computeShuffleId(this.inputContext);
    TezTaskAttemptID tezTaskAttemptID = InputContextUtils.getTezTaskAttemptID(this.inputContext);
    // Ask the AM (via the umbilical) which shuffle servers hold each partition.
    this.partitionToServers = UmbilicalUtils.requestShuffleServer(
        inputContext.getApplicationId(), conf, tezTaskAttemptID, shuffleId);

    shuffleSchedulerThread = Thread.currentThread();
    // Runs synchronously on this thread — start() does not return until scheduling finishes.
    RssShuffleSchedulerCallable rssShuffleSchedulerCallable = new RssShuffleSchedulerCallable();
    rssShuffleSchedulerCallable.call();
  }
+
  /**
   * Idempotent shutdown: wakes/interrupts the scheduler thread, shuts down all running
   * RSS fetchers and the Referee thread, and finally forces the fetcher executor down.
   * Safe to call from any thread; only the first call performs the shutdown steps,
   * but every call ensures the executor is stopped (finally block).
   */
  @Override
  @SuppressFBWarnings("NN_NAKED_NOTIFY")
  public void close() {
    try {
      // getAndSet makes the shutdown steps run exactly once.
      if (!isShutdown.getAndSet(true)) {
        try {
          logProgress();
        } catch (Exception e) {
          LOG.warn("Failed log progress while closing, ignoring and continuing shutdown. Message={}",
              e.getMessage());
        }

        // Notify and interrupt the waiting scheduler thread
        synchronized (this) {
          notifyAll();
        }
        // Interrupt the ShuffleScheduler thread only if the close is invoked by another thread.
        // If this is invoked on the same thread, then the shuffleRunner has already complete, and there's
        // no point interrupting it.
        // The interrupt is needed to unblock any merges or waits which may be happening, so that the thread can
        // exit.
        if (shuffleSchedulerThread != null
            && !Thread.currentThread().equals(shuffleSchedulerThread)) {
          shuffleSchedulerThread.interrupt();
        }

        // Interrupt the fetchers.
        for (RssTezShuffleDataFetcher fetcher : rssRunningFetchers) {
          try {
            fetcher.shutDown();
          } catch (Exception e) {
            // Best effort: one misbehaving fetcher must not block the rest of shutdown.
            LOG.warn(
                "Error while shutting down fetcher. Ignoring and continuing shutdown. Message={}",
                e.getMessage());
          }
        }

        // Kill the Referee thread.
        try {
          referee.interrupt();
          referee.join();
        } catch (InterruptedException e) {
          LOG.warn(
              "Interrupted while shutting down referee. Ignoring and continuing shutdown");
          // Preserve the interrupt status for callers further up the stack.
          Thread.currentThread().interrupt();
        } catch (Exception e) {
          LOG.warn(
              "Error while shutting down referee. Ignoring and continuing shutdown. Message={}",
              e.getMessage());
        }
      }
    } finally {
      // Always runs, even on repeated close() calls or if a step above threw.
      long startTime = System.currentTimeMillis();
      if (!fetcherExecutor.isShutdown()) {
        // Ensure that fetchers respond to cancel request.
        fetcherExecutor.shutdownNow();
      }
      long endTime = System.currentTimeMillis();
      LOG.info("Shutting down fetchers for input: {}, shutdown timetaken: {} ms, "
          + "hasFetcherExecutorStopped: {}", srcNameTrimmed,
          (endTime - startTime), hasFetcherExecutorStopped());
    }
  }
+
  /** Returns true once the fetcher executor has been shut down (see close()). */
  @VisibleForTesting
  @Override
  boolean hasFetcherExecutorStopped() {
    return fetcherExecutor.isShutdown();
  }
+
  /** Returns true once close() has been invoked (flag is set before shutdown steps run). */
  @VisibleForTesting
  @Override
  public boolean isShutdown() {
    return isShutdown.get();
  }
+
  /**
   * Records event-arrival times relative to the scheduler start time.
   * The first call (while FIRST_EVENT_RECEIVED is still 0) sets both counters;
   * every call updates LAST_EVENT_RECEIVED.
   */
  @Override
  protected synchronized void updateEventReceivedTime() {
    long relativeTime = System.currentTimeMillis() - startTime;
    if (firstEventReceived.getValue() == 0) {
      firstEventReceived.setValue(relativeTime);
      lastEventReceived.setValue(relativeTime);
      return;
    }
    lastEventReceived.setValue(relativeTime);
  }
+
  /**
   * Placeholder for tracking shuffle events in case we get multiple spills info for the same
   * attempt (pipelined shuffle / finalMerge disabled in the source).
   */
  static class ShuffleEventInfo {
    BitSet eventsProcessed;        // one bit per spill event id already processed
    int finalEventId = -1;         // 0 indexed; -1 until the FINAL_UPDATE spill is seen
    int attemptNum;                // source attempt number this info belongs to
    String id;                     // "<inputIdentifier>_<attemptNumber>", for diagnostics
    boolean scheduledForDownload;  // whether chunks got scheduled for download (getMapHost)

    ShuffleEventInfo(InputAttemptIdentifier input) {
      this.id = input.getInputIdentifier() + "_" + input.getAttemptNumber();
      this.eventsProcessed = new BitSet();
      this.attemptNum = input.getAttemptNumber();
    }

    // Marks the given spill id as processed; sanity-checks the count once the
    // final event id is known.
    void spillProcessed(int spillId) {
      if (finalEventId != -1) {
        Preconditions.checkState(eventsProcessed.cardinality() <= (finalEventId + 1),
            "Wrong state. eventsProcessed cardinality=" + eventsProcessed.cardinality() + " "
                + "finalEventId=" + finalEventId + ", spillId=" + spillId + ", " + toString());
      }
      eventsProcessed.set(spillId);
    }

    void setFinalEventId(int spillId) {
      finalEventId = spillId;
    }

    // Done when the final event id is known and every spill up to it has been processed.
    boolean isDone() {
      return ((finalEventId != -1) && (finalEventId + 1) == eventsProcessed.cardinality());
    }

    @Override
    public String toString() {
      return "[eventsProcessed=" + eventsProcessed + ", finalEventId=" + finalEventId
          + ", id=" + id + ", attemptNum=" + attemptNum
          + ", scheduledForDownload=" + scheduledForDownload + "]";
    }
  }
+
  /**
   * Records a successfully fetched (or skipped) map output: commits the output, updates
   * byte/input counters, clears failure bookkeeping for the attempt and host, and — for
   * pipelined shuffle — tracks per-spill completion before declaring the input finished.
   * A null {@code output} registers completion without data (empty partition).
   * Duplicate completions for an already-finished input abort the output and return.
   */
  @Override
  public synchronized void copySucceeded(InputAttemptIdentifier srcAttemptIdentifier,
      MapHost host,
      long bytesCompressed,
      long bytesDecompressed,
      long millis,
      MapOutput output,
      boolean isLocalFetch) throws IOException {
    inputContext.notifyProgress();
    if (!isInputFinished(srcAttemptIdentifier.getInputIdentifier())) {
      if (!isLocalFetch) {
        /**
         * Reset it only when it is a non-local-disk copy.
         */
        failedShufflesSinceLastCompletion = 0;
      }
      if (output != null) {
        // A successful copy clears the failure history for this attempt and host.
        failureCounts.remove(srcAttemptIdentifier);
        if (host != null) {
          hostFailures.remove(new HostPort(host.getHost(), host.getPort()));
        }

        output.commit();
        fetchStatsLogger.logIndividualFetchComplete(millis, bytesCompressed, bytesDecompressed,
            output.getType().toString(), srcAttemptIdentifier);
        // Attribute the compressed bytes to the counter matching where the output landed.
        if (output.getType() == Type.DISK) {
          bytesShuffledToDisk.increment(bytesCompressed);
        } else if (output.getType() == Type.DISK_DIRECT) {
          bytesShuffledToDiskDirect.increment(bytesCompressed);
        } else {
          bytesShuffledToMem.increment(bytesCompressed);
        }
        shuffledInputsCounter.increment(1);
      } else {
        // Output null implies that a physical input completion is being
        // registered without needing to fetch data
        skippedInputCounter.increment(1);
      }

      /**
       * In case of pipelined shuffle, it is quite possible that fetchers pulled the FINAL_UPDATE
       * spill in advance due to smaller output size. In such scenarios, we need to wait until
       * we retrieve all spill details to claim success.
       */
      if (!srcAttemptIdentifier.canRetrieveInputInChunks()) {
        // Non-pipelined: one completion finishes the whole input.
        remainingMaps.decrementAndGet();
        setInputFinished(srcAttemptIdentifier.getInputIdentifier());
        numFetchedSpills++;
      } else {
        int inputIdentifier = srcAttemptIdentifier.getInputIdentifier();
        //Allow only one task attempt to proceed.
        if (!validateInputAttemptForPipelinedShuffle(srcAttemptIdentifier)) {
          return;
        }

        ShuffleEventInfo eventInfo = pipelinedShuffleInfoEventsMap.get(inputIdentifier);

        //Possible that Shuffle event handler invoked this, due to empty partitions
        if (eventInfo == null && output == null) {
          eventInfo = new ShuffleEventInfo(srcAttemptIdentifier);
          pipelinedShuffleInfoEventsMap.put(inputIdentifier, eventInfo);
        }

        assert (eventInfo != null);
        eventInfo.spillProcessed(srcAttemptIdentifier.getSpillEventId());
        numFetchedSpills++;

        if (srcAttemptIdentifier.getFetchTypeInfo() == InputAttemptIdentifier.SPILL_INFO.FINAL_UPDATE) {
          eventInfo.setFinalEventId(srcAttemptIdentifier.getSpillEventId());
        }

        //check if we downloaded all spills pertaining to this InputAttemptIdentifier
        if (eventInfo.isDone()) {
          remainingMaps.decrementAndGet();
          setInputFinished(inputIdentifier);
          pipelinedShuffleInfoEventsMap.remove(inputIdentifier);
          if (LOG.isTraceEnabled()) {
            LOG.trace("Removing : " + srcAttemptIdentifier + ", pending: "
                + pipelinedShuffleInfoEventsMap);
          }
        }

        if (LOG.isTraceEnabled()) {
          LOG.trace("eventInfo " + eventInfo.toString());
        }
      }

      if (remainingMaps.get() == 0) {
        notifyAll(); // Notify the getHost() method.
        LOG.info("All inputs fetched for input vertex : " + inputContext.getSourceVertexName());
      }

      // update the status
      lastProgressTime = System.currentTimeMillis();
      totalBytesShuffledTillNow += bytesCompressed;
      logProgress();
      reduceShuffleBytes.increment(bytesCompressed);
      reduceBytesDecompressed.increment(bytesDecompressed);
      if (LOG.isDebugEnabled()) {
        LOG.debug("src task: "
            + TezRuntimeUtils.getTaskAttemptIdentifier(
                inputContext.getSourceVertexName(), srcAttemptIdentifier.getInputIdentifier(),
                srcAttemptIdentifier.getAttemptNumber()) + " done");
      }
    } else {
      // input is already finished. duplicate fetch.
      LOG.warn("Duplicate fetch of input no longer needs to be fetched: " + srcAttemptIdentifier);
      // free the resource - specially memory

      // If the src does not generate data, output will be null.
      if (output != null) {
        output.abort();
      }
    }
    // NEWTEZ Should this be releasing the output, if not committed ? Possible memory leak in case of speculation.
  }
+
+ private boolean
validateInputAttemptForPipelinedShuffle(InputAttemptIdentifier input) {
+ // For pipelined shuffle.
+ // TEZ-2132 for error handling. As of now, fail fast if there is a
different attempt
+ if (input.canRetrieveInputInChunks()) {
+ ShuffleEventInfo eventInfo =
pipelinedShuffleInfoEventsMap.get(input.getInputIdentifier());
+ if (eventInfo != null && input.getAttemptNumber() !=
eventInfo.attemptNum) {
+ /*
+ * Check if current attempt has been scheduled for download.
+ * e.g currentAttemptNum=0, eventsProcessed={}, newAttemptNum=1
+ * If nothing is scheduled in current attempt and no events are
processed
+ * (i.e copySucceeded), we can ignore current attempt and start
processing the new
+ * attempt (e.g LLAP).
+ */
+ if (eventInfo.scheduledForDownload ||
!eventInfo.eventsProcessed.isEmpty()) {
+ IOException exception = new IOException("Previous event already got
scheduled for "
+ + input + ". Previous attempt's data could have been already
merged "
+ + "to memory/disk outputs. Killing (self) this task early."
+ + " currentAttemptNum=" + eventInfo.attemptNum
+ + ", eventsProcessed=" + eventInfo.eventsProcessed
+ + ", scheduledForDownload=" + eventInfo.scheduledForDownload
+ + ", newAttemptNum=" + input.getAttemptNumber());
+ String message = "Killing self as previous attempt data could have
been consumed";
+ killSelf(exception, message);
+ return false;
+ }
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Ignoring current attempt=" + eventInfo.attemptNum + "
with eventInfo="
+ + eventInfo.toString() + "and processing new attempt=" +
input.getAttemptNumber());
+ }
+ }
+ if (eventInfo == null) {
+ pipelinedShuffleInfoEventsMap.put(input.getInputIdentifier(), new
ShuffleEventInfo(input));
+ }
+ }
+ return true;
+ }
+
  /**
   * Logs the failure and asks the exception reporter to kill this task attempt.
   * Overridable for testing.
   */
  @VisibleForTesting
  @Override
  void killSelf(Exception exception, String message) {
    LOG.error(message, exception);
    exceptionReporter.killSelf(exception, message);
  }
+
  // Next inputs-done milestone at which a progress line is emitted; bumped by 50
  // each time a line is logged so that progress logging stays bounded.
  private final AtomicInteger nextProgressLineEventCount = new AtomicInteger(0);

  /**
   * Logs a shuffle-progress line (inputs done, spills fetched, transfer rate in MB/s).
   * Emits only at 50-input milestones, at completion, or during shutdown.
   */
  private void logProgress() {
    int inputsDone = numInputs - remainingMaps.get();
    if (inputsDone > nextProgressLineEventCount.get() || inputsDone == numInputs || isShutdown.get()) {
      nextProgressLineEventCount.addAndGet(50);
      double mbs = (double) totalBytesShuffledTillNow / (1024 * 1024);
      // +1 avoids a divide-by-zero when called within the first second.
      long secsSinceStart = (System.currentTimeMillis() - startTime) / 1000 + 1;

      double transferRate = mbs / secsSinceStart;
      LOG.info("copy(" + inputsDone + " (spillsFetched=" + numFetchedSpills + ") of " + numInputs
          + ". Transfer rate (CumulativeDataFetched/TimeSinceInputStarted)) "
          + mbpsFormat.format(transferRate) + " MB/s)");
    }
  }
+
  /**
   * Records a failed copy of {@code srcAttempt} from {@code host}, informs the AM
   * when warranted, bails out (via isShuffleHealthy) if shuffle health is bad,
   * and otherwise penalizes the host with exponential back-off.
   *
   * @param srcAttempt   the attempt whose fetch failed
   * @param host         the host the fetch was attempted against
   * @param readError    true when the failure was a read error
   * @param connectError true when the failure was a connection error
   * @param isLocalFetch true when this was a local-disk fetch
   */
  @Override
  public synchronized void copyFailed(InputAttemptIdentifier srcAttempt, MapHost host, boolean readError,
      boolean connectError, boolean isLocalFetch) {
    failedShuffleCounter.increment(1);
    inputContext.notifyProgress();
    int failures = incrementAndGetFailureAttempt(srcAttempt);

    if (!isLocalFetch) {
      /**
       * Track the number of failures that has happened since last completion.
       * This gets reset on a successful copy.
       */
      failedShufflesSinceLastCompletion++;
    }

    /**
     * Inform AM:
     *    - In case of read/connect error
     *    - In case attempt failures exceed threshold of
     *      maxFetchFailuresBeforeReporting (5)
     * Bail-out if needed:
     *    - Check whether individual attempt crossed failure threshold limits
     *    - Check overall shuffle health. Bail out if needed.*
     */

    //TEZ-2890
    boolean shouldInformAM =
        (reportReadErrorImmediately && (readError || connectError))
            || ((failures % maxFetchFailuresBeforeReporting) == 0);

    if (shouldInformAM) {
      //Inform AM. In case producer needs to be restarted, it is handled at AM.
      informAM(srcAttempt);
    }

    //Restart consumer in case shuffle is not healthy
    if (!isShuffleHealthy(srcAttempt)) {
      return;
    }

    penalizeHost(host, failures);
  }
+
  /**
   * Checks whether {@code srcAttempt} has failed at least {@code abortFailureLimit}
   * times; if so reports an IOException to the exception reporter (failing this
   * task so the AM may re-schedule it) and returns true.
   * NOTE: the method name's spelling ("Exceeed") is kept as-is for compatibility.
   */
  private boolean isAbortLimitExceeedFor(InputAttemptIdentifier srcAttempt) {
    int attemptFailures = getFailureCount(srcAttempt);
    if (attemptFailures >= abortFailureLimit) {
      // This task has seen too many fetch failures - report it as failed. The
      // AM may retry it if max failures has not been reached.

      // Between the task and the AM - someone needs to determine who is at
      // fault. If there's enough errors seen on the task, before the AM informs
      // it about source failure, the task considers itself to have failed and
      // allows the AM to re-schedule it.
      String errorMsg = "Failed " + attemptFailures + " times trying to "
          + "download from " + TezRuntimeUtils.getTaskAttemptIdentifier(
              inputContext.getSourceVertexName(),
              srcAttempt.getInputIdentifier(),
              srcAttempt.getAttemptNumber()) + ". threshold=" + abortFailureLimit;
      IOException ioe = new IOException(errorMsg);
      // Shuffle knows how to deal with failures post shutdown via the onFailure hook
      exceptionReporter.reportException(ioe);
      return true;
    }
    return false;
  }
+
+ private void penalizeHost(MapHost host, int failures) {
+ host.penalize();
+
+ HostPort hostPort = new HostPort(host.getHost(), host.getPort());
+ // TEZ-922 hostFailures isn't really used for anything apart from
+ // hasFailedAcrossNodes().Factor it into error
+ // reporting / potential blacklisting of hosts.
+ if (hostFailures.containsKey(hostPort)) {
+ IntWritable x = hostFailures.get(hostPort);
+ x.set(x.get() + 1);
+ } else {
+ hostFailures.put(hostPort, new IntWritable(1));
+ }
+
+ long delay = (long) (INITIAL_PENALTY * Math.pow(PENALTY_GROWTH_RATE,
failures));
+ penalties.add(new Penalty(host, delay));
+ }
+
+ private int getFailureCount(InputAttemptIdentifier srcAttempt) {
+ IntWritable failureCount = failureCounts.get(srcAttempt);
+ return (failureCount == null) ? 0 : failureCount.get();
+ }
+
+ private int incrementAndGetFailureAttempt(InputAttemptIdentifier srcAttempt)
{
+ int failures = 1;
+ if (failureCounts.containsKey(srcAttempt)) {
+ IntWritable x = failureCounts.get(srcAttempt);
+ x.set(x.get() + 1);
+ failures = x.get();
+ } else {
+ failureCounts.put(srcAttempt, new IntWritable(1));
+ }
+ return failures;
+ }
+
  /**
   * Reports a local (non-remote) shuffle error straight to the exception reporter.
   */
  @Override
  public void reportLocalError(IOException ioe) {
    LOG.error(srcNameTrimmed + ": " + "Shuffle failed : caused by local error", ioe);
    // Shuffle knows how to deal with failures post shutdown via the onFailure hook
    exceptionReporter.reportException(ioe);
  }
+
  // Notify AM
  /**
   * Sends an InputReadErrorEvent for {@code srcAttempt} to the AM so that the
   * producer can be restarted there if necessary.
   */
  private void informAM(InputAttemptIdentifier srcAttempt) {
    LOG.info(
        srcNameTrimmed + ": " + "Reporting fetch failure for InputIdentifier: "
            + srcAttempt + " taskAttemptIdentifier: " + TezRuntimeUtils
            .getTaskAttemptIdentifier(inputContext.getSourceVertexName(),
                srcAttempt.getInputIdentifier(),
                srcAttempt.getAttemptNumber()) + " to AM.");
    List<Event> failedEvents = Lists.newArrayListWithCapacity(1);
    failedEvents.add(InputReadErrorEvent.create(
        "Fetch failure for " + TezRuntimeUtils
            .getTaskAttemptIdentifier(inputContext.getSourceVertexName(),
                srcAttempt.getInputIdentifier(),
                srcAttempt.getAttemptNumber()) + " to jobtracker.",
        srcAttempt.getInputIdentifier(),
        srcAttempt.getAttemptNumber()));

    inputContext.sendEvents(failedEvents);
  }
+
+ /**
+ * To determine if failures happened across nodes or not. This will help in
+ * determining whether this task needs to be restarted or source needs to
+ * be restarted.
+ *
+ * @param logContext context info for logging
+ * @return boolean true indicates this task needs to be restarted
+ */
+ private boolean hasFailedAcrossNodes(String logContext) {
+ int numUniqueHosts = uniqueHosts.size();
+ Preconditions.checkArgument(numUniqueHosts > 0, "No values in unique
hosts");
+ int threshold = Math.max(3,
+ (int) Math.ceil(numUniqueHosts * hostFailureFraction));
+ int total = 0;
+ boolean failedAcrossNodes = false;
+ for (HostPort host : uniqueHosts) {
+ IntWritable failures = hostFailures.get(host);
+ if (failures != null && failures.get() > minFailurePerHost) {
+ total++;
+ failedAcrossNodes = (total > (threshold * minFailurePerHost));
+ if (failedAcrossNodes) {
+ break;
+ }
+ }
+ }
+
+ LOG.info(logContext + ", numUniqueHosts=" + numUniqueHosts
+ + ", hostFailureThreshold=" + threshold
+ + ", hostFailuresCount=" + hostFailures.size()
+ + ", hosts crossing threshold=" + total
+ + ", reducerFetchIssues=" + failedAcrossNodes
+ );
+
+ return failedAcrossNodes;
+ }
+
+ private boolean allEventsReceived() {
+ if (!pipelinedShuffleInfoEventsMap.isEmpty()) {
+ return (pipelinedShuffleInfoEventsMap.size() == numInputs);
+ } else {
+ //no pipelining
+ return ((pathToIdentifierMap.size() + skippedInputCounter.getValue()) ==
numInputs);
+ }
+ }
+
+ private boolean isAllInputFetched() {
+ return allEventsReceived() && (successRssPartitionSet.size() >=
allRssPartition.size());
+ }
+
  /**
   * Check if consumer needs to be restarted based on total failures w.r.t
   * completed outputs and based on number of errors that have happened since
   * last successful completion. Consider into account whether failures have
   * been seen across different nodes.
   *
   * @param logContext context info for logging
   * @return true to indicate fetchers are healthy
   */
  private boolean isFetcherHealthy(String logContext) {
    long totalFailures = failedShuffleCounter.getValue();
    int doneMaps = numInputs - remainingMaps.get();

    // Healthy when the overall failure ratio stays below the allowed fraction.
    boolean fetcherHealthy = true;
    if (doneMaps > 0) {
      fetcherHealthy = (((float) totalFailures / (totalFailures + doneMaps)) < maxAllowedFailedFetchFraction);
    }

    if (fetcherHealthy) {
      //Compute this logic only when all events are received
      if (allEventsReceived()) {
        if (hostFailureFraction > 0) {
          boolean failedAcrossNodes = hasFailedAcrossNodes(logContext);
          if (failedAcrossNodes) {
            return false; //not healthy
          }
        }

        if (checkFailedFetchSinceLastCompletion) {
          /**
           * remainingMaps works better instead of pendingHosts in the
           * following condition because of the way the fetcher reports failures
           */
          if (failedShufflesSinceLastCompletion >= remainingMaps.get() * minFailurePerHost) {
            /**
             * Check if lots of errors are seen after last progress time.
             *
             * E.g totalFailures = 20. doneMaps = 320 - 300;
             * fetcherHealthy = (20/(20+300)) < 0.5. So reducer would be marked as healthy.
             * Assume 20 errors happen when downloading the last 20 attempts. Host failure & individual
             * attempt failures would keep increasing; but at very slow rate 15 * 180 seconds per
             * attempt to find out the issue.
             *
             * Instead consider the new errors with the pending items to be fetched.
             * Assume 21 new errors happened after last progress; remainingMaps = (320-300) = 20;
             * (21 / (21 + 20)) > 0.5
             * So we reset the reducer to unhealthy here (special case)
             *
             * In normal conditions (i.e happy path), this wouldn't even cause any issue as
             * failedShufflesSinceLastCompletion is reset as soon as we see successful download.
             */

            fetcherHealthy = (((float) failedShufflesSinceLastCompletion
                / (failedShufflesSinceLastCompletion + remainingMaps.get())) < maxAllowedFailedFetchFraction);

            LOG.info(logContext + ", fetcherHealthy=" + fetcherHealthy
                + ", failedShufflesSinceLastCompletion="
                + failedShufflesSinceLastCompletion
                + ", remainingMaps=" + remainingMaps.get()
            );
          }
        }
      }
    }
    return fetcherHealthy;
  }
+
  /**
   * Decides whether the shuffle as a whole is still healthy after a failure of
   * {@code srcAttempt}. Reports a fatal exception (failing this task) and returns
   * false when fetchers are unhealthy AND the reducer has made insufficient
   * progress or has stalled; otherwise returns true.
   */
  @Override
  boolean isShuffleHealthy(InputAttemptIdentifier srcAttempt) {

    if (isAbortLimitExceeedFor(srcAttempt)) {
      return false;
    }

    final float MIN_REQUIRED_PROGRESS_PERCENT = minReqProgressFraction;
    final float MAX_ALLOWED_STALL_TIME_PERCENT = maxStallTimeFraction;

    int doneMaps = numInputs - remainingMaps.get();

    String logContext = "srcAttempt=" + srcAttempt.toString();
    boolean fetcherHealthy = isFetcherHealthy(logContext);

    // check if the reducer has progressed enough
    boolean reducerProgressedEnough = (((float)doneMaps / numInputs) >= MIN_REQUIRED_PROGRESS_PERCENT);

    // check if the reducer is stalled for a long time
    // duration for which the reducer is stalled
    int stallDuration = (int)(System.currentTimeMillis() - lastProgressTime);

    // duration for which the reducer ran with progress
    int shuffleProgressDuration = (int)(lastProgressTime - startTime);

    boolean reducerStalled = (shuffleProgressDuration > 0) && (((float)stallDuration / shuffleProgressDuration)
        >= MAX_ALLOWED_STALL_TIME_PERCENT);

    // kill if not healthy and has insufficient progress
    if ((failureCounts.size() >= maxFailedUniqueFetches || failureCounts.size() == (numInputs - doneMaps))
        && !fetcherHealthy && (!reducerProgressedEnough || reducerStalled)) {
      String errorMsg = (srcNameTrimmed + ": "
          + "Shuffle failed with too many fetch failures and insufficient progress!"
          + "failureCounts=" + failureCounts.size()
          + ", pendingInputs=" + (numInputs - doneMaps)
          + ", fetcherHealthy=" + fetcherHealthy
          + ", reducerProgressedEnough=" + reducerProgressedEnough
          + ", reducerStalled=" + reducerStalled);
      LOG.error(errorMsg);
      if (LOG.isDebugEnabled()) {
        LOG.debug("Host failures=" + hostFailures.keySet());
      }
      // Shuffle knows how to deal with failures post shutdown via the onFailure hook
      exceptionReporter.reportException(new IOException(errorMsg));
      return false;
    }
    return true;
  }
+
  /**
   * Registers a completed map output for {@code partitionId} on the given host:
   * records the partition and attempt, creates/updates the MapHost entry keyed by
   * (host, port, partition), and marks the host pending so a fetcher can pick it up.
   */
  @Override
  public synchronized void addKnownMapOutput(String inputHostName, int port, int partitionId,
      CompositeInputAttemptIdentifier srcAttempt) {

    LOG.info("AddKnownMapOutput thread:{}, obj:{}, RssShuffleScheduler, addKnownMapOutput, inputHostName length:{}, "
            + "port:{}, partitionId:{}, srcAttempt:{}, inputHostName:{}",
        Thread.currentThread().getName(), this, inputHostName.length(), port, partitionId, srcAttempt, inputHostName);

    allRssPartition.add(partitionId);
    if (!partitionIdToSuccessMapTaskAttempts.containsKey(partitionId)) {
      partitionIdToSuccessMapTaskAttempts.put(partitionId, new HashSet<>());
    }
    partitionIdToSuccessMapTaskAttempts.get(partitionId).add(srcAttempt);

    uniqueHosts.add(new HostPort(inputHostName, port));
    HostPortPartition identifier = new HostPortPartition(inputHostName, port, partitionId);

    // Lazily create the MapHost entry for this (host, port, partition).
    MapHost host = mapLocations.get(identifier);
    if (host == null) {
      host = new MapHost(inputHostName, port, partitionId, srcAttempt.getInputIdentifierCount());
      mapLocations.put(identifier, host);
    }

    //Allow only one task attempt to proceed.
    if (!validateInputAttemptForPipelinedShuffle(srcAttempt)) {
      return;
    }

    host.addKnownMap(srcAttempt);
    // Map each expanded (path, partition) pair back to its attempt identifier.
    for (int i = 0; i < srcAttempt.getInputIdentifierCount(); i++) {
      PathPartition pathPartition = new PathPartition(srcAttempt.getPathComponent(), partitionId + i);
      pathToIdentifierMap.put(pathPartition, srcAttempt.expand(i));
    }

    // Mark the host as pending
    if (host.getState() == MapHost.State.PENDING) {
      pendingHosts.add(host);
      notifyAll();
    }
  }
+
  /**
   * Marks {@code srcAttempt} obsolete. For pipelined shuffle: if nothing from
   * the attempt was downloaded yet, simply drop the tracking entry; if its data
   * may already be merged, kill self. Otherwise the attempt is recorded in
   * obsoleteInputs so fetchers skip it.
   */
  @Override
  public void obsoleteInput(InputAttemptIdentifier srcAttempt) {
    // The incoming srcAttempt does not contain a path component.
    LOG.info(srcNameTrimmed + ": " + "Adding obsolete input: " + srcAttempt);
    ShuffleEventInfo eventInfo = pipelinedShuffleInfoEventsMap.get(srcAttempt.getInputIdentifier());

    //Pipelined shuffle case (where pipelinedShuffleInfoEventsMap gets populated).
    //Fail fast here.
    if (eventInfo != null) {
      // In case this we haven't started downloading it, get rid of it.
      if (eventInfo.eventsProcessed.isEmpty() && !eventInfo.scheduledForDownload) {
        // obsoleted anyways; no point tracking if nothing is started
        pipelinedShuffleInfoEventsMap.remove(srcAttempt.getInputIdentifier());
        if (LOG.isDebugEnabled()) {
          LOG.debug("Removing " + eventInfo + " from tracking");
        }
        return;
      }
      IOException exception = new IOException(srcAttempt + " is marked as obsoleteInput, but it "
          + "exists in shuffleInfoEventMap. Some data could have been already merged "
          + "to memory/disk outputs. Failing the fetch early. eventInfo:" + eventInfo.toString());
      String message = "Got obsolete event. Killing self as attempt's data could have been consumed";
      killSelf(exception, message);
      return;
    }
    synchronized (this) {
      obsoleteInputs.add(srcAttempt);
    }
  }
+
  /** Returns {@code srcAttempt} to the host's known-map list (e.g. after a failed fetch). */
  @Override
  public synchronized void putBackKnownMapOutput(MapHost host, InputAttemptIdentifier srcAttempt) {
    host.addKnownMap(srcAttempt);
  }

  /**
   * Blocks until a pending host is available or all input has been fetched, then
   * removes and returns a randomly chosen pending host (marked busy). Returns
   * null when there is nothing left to fetch.
   *
   * @throws InterruptedException if interrupted while waiting
   */
  @Override
  public synchronized MapHost getHost() throws InterruptedException {
    while (pendingHosts.isEmpty() && !isAllInputFetched()) {
      if (LOG.isInfoEnabled()) {
        LOG.info("RssShuffleScheduler getHost, pendingHosts:{}, remainingMaps:{}, all partition:{}, "
                + "success partition:{}", pendingHosts.size(), remainingMaps.get(),
            allRssPartition.size(), successRssPartitionSet.size());
        LOG.info("PendingHosts=" + pendingHosts + ",remainingMaps:" + remainingMaps.get());
      }
      // Releases the monitor for up to 1s and reports liveness to the AM.
      waitAndNotifyProgress();
    }

    if (!pendingHosts.isEmpty()) {
      MapHost host = null;
      // Advance a random number of steps into the set to spread load across hosts.
      Iterator<MapHost> iter = pendingHosts.iterator();
      int numToPick = random.nextInt(pendingHosts.size());
      for (int i = 0; i <= numToPick; ++i) {
        host = iter.next();
      }

      pendingHosts.remove(host);
      host.markBusy();
      if (LOG.isDebugEnabled()) {
        LOG.debug(srcNameTrimmed + ": " + "Assigning " + host + " with " + host.getNumKnownMapOutputs()
            + " to " + Thread.currentThread().getName());
      }
      shuffleStart.set(System.currentTimeMillis());
      return host;
    } else {
      return null;
    }
  }

  /** Looks up the attempt identifier registered for the given (path, partition) pair. */
  @Override
  public InputAttemptIdentifier getIdentifierForFetchedOutput(
      String path, int reduceId) {
    return pathToIdentifierMap.get(new PathPartition(path, reduceId));
  }
+
+ private synchronized boolean inputShouldBeConsumed(InputAttemptIdentifier
id) {
+ boolean isInputFinished = false;
+ if (id instanceof CompositeInputAttemptIdentifier) {
+ CompositeInputAttemptIdentifier cid =
(CompositeInputAttemptIdentifier)id;
+ isInputFinished = isInputFinished(cid.getInputIdentifier(),
cid.getInputIdentifier()
+ + cid.getInputIdentifierCount());
+ } else {
+ isInputFinished = isInputFinished(id.getInputIdentifier());
+ }
+ return !obsoleteInputs.contains(id) && !isInputFinished;
+ }
+
  /**
   * Drains the host's known maps, deduplicates them per input (keeping the latest
   * attempt; for pipelined shuffle keeping distinct spills of the same attempt),
   * and returns up to maxTaskOutputAtOnce identifiers to fetch; the overflow is
   * put back on the host.
   */
  @Override
  public synchronized List<InputAttemptIdentifier> getMapsForHost(MapHost host) {
    List<InputAttemptIdentifier> origList = host.getAndClearKnownMaps();

    ListMultimap<Integer, InputAttemptIdentifier> dedupedList = LinkedListMultimap.create();

    Iterator<InputAttemptIdentifier> listItr = origList.iterator();
    while (listItr.hasNext()) {
      // we may want to try all versions of the input but with current retry
      // behavior older ones are likely to be lost and should be ignored.
      // This may be removed after TEZ-914
      InputAttemptIdentifier id = listItr.next();
      if (inputShouldBeConsumed(id)) {
        Integer inputNumber = Integer.valueOf(id.getInputIdentifier());
        List<InputAttemptIdentifier> oldIdList = dedupedList.get(inputNumber);

        if (oldIdList == null || oldIdList.isEmpty()) {
          dedupedList.put(inputNumber, id);
          continue;
        }

        //In case of pipelined shuffle, we can have multiple spills. In such cases, we can have
        // more than one item in the oldIdList.
        boolean addIdentifierToList = false;
        Iterator<InputAttemptIdentifier> oldIdIterator = oldIdList.iterator();
        while (oldIdIterator.hasNext()) {
          InputAttemptIdentifier oldId = oldIdIterator.next();

          //no need to add if spill ids are same
          if (id.canRetrieveInputInChunks()) {
            if (oldId.getSpillEventId() == id.getSpillEventId()) {
              // need to handle deterministic spills later.
              addIdentifierToList = false;
              continue;
            } else if (oldId.getAttemptNumber() == id.getAttemptNumber()) {
              //but with different spill id.
              addIdentifierToList = true;
              break;
            }
          }

          //if its from different attempt, take the latest attempt
          if (oldId.getAttemptNumber() < id.getAttemptNumber()) {
            //remove existing identifier
            oldIdIterator.remove();
            LOG.warn("Old Src for InputIndex: " + inputNumber + " with attemptNumber: "
                + oldId.getAttemptNumber()
                + " was not determined to be invalid. Ignoring it for now in favour of "
                + id.getAttemptNumber());
            addIdentifierToList = true;
            break;
          }
        }
        if (addIdentifierToList) {
          dedupedList.put(inputNumber, id);
        }
      } else {
        LOG.info("Ignoring finished or obsolete source: " + id);
      }
    }

    // Compute the final list, limited by NUM_FETCHERS_AT_ONCE
    List<InputAttemptIdentifier> result = new ArrayList<InputAttemptIdentifier>();
    int includedMaps = 0;
    int totalSize = dedupedList.size();

    for (Integer inputIndex : dedupedList.keySet()) {
      List<InputAttemptIdentifier> attemptIdentifiers = dedupedList.get(inputIndex);
      for (InputAttemptIdentifier inputAttemptIdentifier : attemptIdentifiers) {
        if (includedMaps++ >= maxTaskOutputAtOnce) {
          // Over the per-call limit: return the identifier to the host for later.
          host.addKnownMap(inputAttemptIdentifier);
        } else {
          if (inputAttemptIdentifier.canRetrieveInputInChunks()) {
            ShuffleEventInfo shuffleEventInfo =
                pipelinedShuffleInfoEventsMap.get(inputAttemptIdentifier.getInputIdentifier());
            if (shuffleEventInfo != null) {
              shuffleEventInfo.scheduledForDownload = true;
            }
          }
          result.add(inputAttemptIdentifier);
        }
      }
    }
    if (LOG.isDebugEnabled()) {
      LOG.debug("assigned " + includedMaps + " of " + totalSize + " to "
          + host + " to " + Thread.currentThread().getName());
    }
    return result;
  }
+
  /**
   * Returns a host to the pending queue (unless it is penalized) and wakes up
   * threads waiting for a pending host.
   */
  @Override
  public synchronized void freeHost(MapHost host) {
    if (host.getState() != MapHost.State.PENALIZED) {
      if (host.markAvailable() == MapHost.State.PENDING) {
        pendingHosts.add(host);
        notifyAll();
      }
    }
    if (LOG.isDebugEnabled()) {
      LOG.debug(host + " freed by " + Thread.currentThread().getName() + " in "
          + (System.currentTimeMillis() - shuffleStart.get()) + "ms");
    }
  }

  /** Clears all host/location/identifier tracking state. */
  @Override
  public synchronized void resetKnownMaps() {
    mapLocations.clear();
    obsoleteInputs.clear();
    pendingHosts.clear();
    pathToIdentifierMap.clear();
  }

  /**
   * Utility method to check if the Shuffle data fetch is complete.
   * @return true if complete
   */
  @Override
  public synchronized boolean isDone() {
    return remainingMaps.get() == 0;
  }
+
+ /**
+ * A structure that records the penalty for a host.
+ */
+ private static class Penalty implements Delayed {
+ MapHost host;
+ private long endTime;
+
+ Penalty(MapHost host, long delay) {
+ this.host = host;
+ this.endTime = System.currentTimeMillis() + delay;
+ }
+
+ @Override
+ public long getDelay(TimeUnit unit) {
+ long remainingTime = endTime - System.currentTimeMillis();
+ return unit.convert(remainingTime, TimeUnit.MILLISECONDS);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Penalty penalty = (Penalty) o;
+ return endTime == penalty.endTime && Objects.equals(host, penalty.host);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(host, endTime);
+ }
+
+ @Override
+ public int compareTo(Delayed o) {
+ long other = ((Penalty) o).endTime;
+ return endTime == other ? 0 : (endTime < other ? -1 : 1);
+ }
+
+ }
+
  /**
   * A thread that takes hosts off of the penalty list when the timer expires.
   */
  private class Referee extends Thread {
    Referee() {
      setName("ShufflePenaltyReferee {"
          + TezUtilsInternal.cleanVertexName(inputContext.getSourceVertexName()) + "}");
      // Daemon so the JVM can exit even if this thread is still blocked on take().
      setDaemon(true);
    }

    @Override
    public void run() {
      try {
        while (!isShutdown.get()) {
          // take the first host that has an expired penalty
          MapHost host = penalties.take().host;
          synchronized (RssShuffleScheduler.this) {
            if (host.markAvailable() == MapHost.State.PENDING) {
              pendingHosts.add(host);
              RssShuffleScheduler.this.notifyAll();
            }
          }
        }
      } catch (InterruptedException ie) {
        Thread.currentThread().interrupt();
        // This handles shutdown of the entire fetch / merge process.
      } catch (Throwable t) {
        // Shuffle knows how to deal with failures post shutdown via the onFailure hook
        exceptionReporter.reportException(t);
      }
    }
  }
+
+
  /** Marks the input at {@code inputIndex} as finished in the shared bitmap. */
  @Override
  void setInputFinished(int inputIndex) {
    synchronized (finishedMaps) {
      finishedMaps.set(inputIndex, true);
    }
  }

  /** True when the input at {@code inputIndex} has been marked finished. */
  @Override
  boolean isInputFinished(int inputIndex) {
    synchronized (finishedMaps) {
      return finishedMaps.get(inputIndex);
    }
  }

  /** True when every input in [inputIndex, inputEnd] is marked finished. */
  @Override
  boolean isInputFinished(int inputIndex, int inputEnd) {
    synchronized (finishedMaps) {
      // nextClearBit past inputEnd means no unset bit exists in the range.
      return finishedMaps.nextClearBit(inputIndex) > inputEnd;
    }
  }
+
+ private class RssShuffleSchedulerCallable extends CallableWithNdc<Void> {
+
+ @Override
+ protected Void callInternal() throws IOException, InterruptedException,
TezException, RssException {
+ while (!isShutdown.get() && !isAllInputFetched()) {
+ LOG.info("Now allEventsReceived: " + allEventsReceived());
+
+ synchronized (RssShuffleScheduler.this) {
+ while (!allEventsReceived()
+ || ((rssRunningFetchers.size() >= numFetchers ||
pendingHosts.isEmpty()) && !isAllInputFetched())) {
+ try {
+ LOG.info("RssShuffleSchedulerCallable, wait pending hosts,
pendingHosts:{}.", pendingHosts.isEmpty());
+ waitAndNotifyProgress();
+ } catch (InterruptedException e) {
+ if (isShutdown.get()) {
+ LOG.info(srcNameTrimmed + ": " + "Interrupted while waiting
for fetchers to complete"
+ + "and hasBeenShutdown. Breaking out of
ShuffleSchedulerCallable loop");
+ Thread.currentThread().interrupt();
+ break;
+ } else {
+ throw e;
+ }
+ }
+ }
+ }
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(srcNameTrimmed + ": " + "NumCompletedInputs: {}" +
(numInputs - remainingMaps.get()));
+ }
+ // Ensure there's memory available before scheduling the next Fetcher.
+ try {
+ // If merge is on, block
+ mergeManager.waitForInMemoryMerge();
+ // In case usedMemory > memorylimit, wait until some memory is
released
+ mergeManager.waitForShuffleToMergeMemory();
+ } catch (InterruptedException e) {
+ if (isShutdown.get()) {
+ LOG.info(srcNameTrimmed + ": Interrupted while waiting for merge
to complete and hasBeenShutdown. "
+ + "Breaking out of ShuffleSchedulerCallable loop");
+ Thread.currentThread().interrupt();
+ break;
+ } else {
+ throw e;
+ }
+ }
+
+ if (!isShutdown.get() && !isAllInputFetched()) {
+ synchronized (RssShuffleScheduler.this) {
+ int numFetchersToRun = numFetchers - rssRunningFetchers.size();
+ int count = 0;
+ while (count < numFetchersToRun && !isShutdown.get() &&
!isAllInputFetched()) {
+ MapHost mapHost;
+ try {
+ mapHost = getHost(); // Leads to a wait.
+ } catch (InterruptedException e) {
+ if (isShutdown.get()) {
+ LOG.info(srcNameTrimmed + ": Interrupted while waiting for
host and hasBeenShutdown. "
+ + "Breaking out of ShuffleSchedulerCallable loop");
+ Thread.currentThread().interrupt();
+ break;
+ } else {
+ throw e;
+ }
+ }
+ if (mapHost == null) {
+ LOG.info("Get null mapHost and break out.");
+ break; // Check for the exit condition.
+ }
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(srcNameTrimmed + ": " + "Processing pending host: "
+ mapHost.toString());
+ }
+ if (!isShutdown.get()) {
+ count++;
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(srcNameTrimmed + ": " + "Scheduling fetch for
inputHost: {}",
+ mapHost.getHostIdentifier() + ":" +
mapHost.getPartitionId());
+ }
+
+ if (isFirstRssPartitionFetch(mapHost)) {
+ int partitionId = mapHost.getPartitionId();
+ RssTezShuffleDataFetcher rssTezShuffleDataFetcher =
constructRssFetcherForPartition(mapHost,
+ partitionToServers.get(partitionId));
+
+ rssRunningFetchers.add(rssTezShuffleDataFetcher);
+ ListenableFuture<Void> future =
fetcherExecutor.submit(rssTezShuffleDataFetcher);
+ Futures.addCallback(future, new
FetchFutureCallback(rssTezShuffleDataFetcher),
+ MoreExecutors.directExecutor());
+ } else {
+ for (int i = 0; i < mapHost.getAndClearKnownMaps().size();
i++) {
+ remainingMaps.decrementAndGet();
+ }
+ LOG.info("Partition was fetched, remainingMaps desc, now
value:{}", remainingMaps.get());
+ }
+ }
+ }
+ }
+ }
+ }
+ LOG.info("Shutting down FetchScheduler for input: {}, wasInterrupted={}",
+ srcNameTrimmed, Thread.currentThread().isInterrupted());
+ if (!fetcherExecutor.isShutdown()) {
+ fetcherExecutor.shutdownNow();
+ }
+ return null;
+ }
+ }
+
  /**
   * Atomically checks whether {@code mapHost}'s partition has neither a running
   * nor a completed fetch; when so, registers it as running and returns true so
   * exactly one fetcher is started per partition.
   */
  private synchronized boolean isFirstRssPartitionFetch(MapHost mapHost) {
    Integer partitionId = mapHost.getPartitionId();
    LOG.info("Check isFirstRssPartitionFetch, mapHost:{},partitionId:{}", mapHost, partitionId);

    if (runningRssPartitionMap.containsKey(partitionId) || successRssPartitionSet.contains(partitionId)) {
      return false;
    }
    runningRssPartitionMap.put(partitionId, mapHost);
    return true;
  }

  /** Builds a JobConf copy of the scheduler's configuration for the remote reader. */
  private JobConf getRemoteConf() {
    return new JobConf(conf);
  }
+
  /**
   * Reports liveness to the AM, then waits up to 1s on this scheduler's monitor
   * (released early by notifyAll from completion/host events).
   */
  private synchronized void waitAndNotifyProgress() throws InterruptedException {
    inputContext.notifyProgress();
    wait(1000);
  }
+
+
+ @VisibleForTesting
+ private RssTezShuffleDataFetcher constructRssFetcherForPartition(MapHost
mapHost,
+ List<ShuffleServerInfo> shuffleServerInfoList) throws RssException {
+ Set<ShuffleServerInfo> shuffleServerInfoSet = new
HashSet<>(shuffleServerInfoList);
+ LOG.info("ConstructRssFetcherForPartition, shuffleServerInfoSet: {}",
shuffleServerInfoSet);
+
+ Optional<InputAttemptIdentifier> attempt =
partitionIdToSuccessMapTaskAttempts.get(
+ mapHost.getPartitionId()).stream().findFirst();
+ LOG.info("ConstructRssFetcherForPartition, partitionId:{}, take a
attempt:{}", mapHost.getPartitionId(), attempt);
+
+ ShuffleWriteClient writeClient = RssTezUtils.createShuffleClient(conf);
+ String clientType = "";
+ int shuffleId = InputContextUtils.computeShuffleId(inputContext);
+ Roaring64NavigableMap blockIdBitmap = writeClient.getShuffleResult(
+ clientType, shuffleServerInfoSet, applicationId, shuffleId,
mapHost.getPartitionId());
+ writeClient.close();
+
+ int appAttemptId = IdUtils.getAppAttemptId();
+ Roaring64NavigableMap taskIdBitmap = RssTezUtils.fetchAllRssTaskIds(
+ partitionIdToSuccessMapTaskAttempts.get(mapHost.getPartitionId()),
this.numInputs,
+ appAttemptId);
+
+ LOG.info("In reduce: {}, RSS Tez client has fetched blockIds and taskIds
successfully, partitionId:{}.",
+ inputContext.getTaskVertexName(), mapHost.getPartitionId());
+
+ // start fetcher to fetch blocks from RSS servers
+ if (!taskIdBitmap.isEmpty()) {
+ LOG.info("In reduce: " + inputContext.getTaskVertexName()
+ + ", Rss Tez client starts to fetch blocks from RSS server");
+ JobConf readerJobConf = getRemoteConf();
+
+ int partitionNum = partitionToServers.size();
+ boolean expectedTaskIdsBitmapFilterEnable = shuffleServerInfoSet.size()
> 1;
+
+ CreateShuffleReadClientRequest request = new
CreateShuffleReadClientRequest(
+ applicationId,
+ shuffleId,
+ mapHost.getPartitionId(),
+ basePath,
+ partitionNumPerRange,
+ partitionNum,
+ blockIdBitmap,
+ taskIdBitmap,
+ shuffleServerInfoList,
+ readerJobConf,
+ new TezIdHelper(),
+ expectedTaskIdsBitmapFilterEnable);
+
+ ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance().createShuffleReadClient(request);
+ RssTezShuffleDataFetcher fetcher = new RssTezShuffleDataFetcher(
+
partitionIdToSuccessMapTaskAttempts.get(mapHost.getPartitionId()).iterator().next(),
+ mapHost.getPartitionId(),
+ mergeManager, inputContext.getCounters(), shuffleReadClient,
blockIdBitmap.getLongCardinality(),
+ RssTezConfig.toRssConf(conf), exceptionReporter);
+ return fetcher;
+ }
+
+ throw new RssException("Construct rss fetcher partition task failed");
+ }
+
  /**
   * Builds a vanilla (non-RSS) ordered-grouped fetcher for {@code mapHost} from
   * the scheduler's configuration and counters. Overridable for testing.
   */
  @VisibleForTesting
  @Override
  FetcherOrderedGrouped constructFetcherForHost(MapHost mapHost) {
    return new FetcherOrderedGrouped(httpConnectionParams, RssShuffleScheduler.this, allocator,
        exceptionReporter, jobTokenSecretManager, ifileReadAhead, ifileReadAheadLength,
        codec, conf, localDiskFetchEnabled, localHostname, shufflePort, srcNameTrimmed, mapHost,
        ioErrsCounter, wrongLengthErrsCounter, badIdErrsCounter, wrongMapErrsCounter,
        connectionErrsCounter, wrongReduceErrsCounter, applicationId, dagId, asyncHttp, sslShuffle,
        verifyDiskChecksum, compositeFetch);
  }
+
+ private class FetchFutureCallback implements FutureCallback<Void> {
+
+ private final RssTezShuffleDataFetcher rssFetcherOrderedGrouped;
+ private final Integer partitionId;
+
+ FetchFutureCallback(RssTezShuffleDataFetcher rssFetcherOrderedGrouped) {
+ this.rssFetcherOrderedGrouped = rssFetcherOrderedGrouped;
+ this.partitionId = rssFetcherOrderedGrouped.getPartitionId();
+ }
+
+ private void doBookKeepingForFetcherComplete() {
+ synchronized (RssShuffleScheduler.this) {
+ rssRunningFetchers.remove(rssFetcherOrderedGrouped);
+ RssShuffleScheduler.this.notifyAll();
+ }
+ }
+
+
+ @Override
+ public void onSuccess(Void result) {
+ rssFetcherOrderedGrouped.shutDown();
+
+ if (isShutdown.get()) {
+ LOG.info(srcNameTrimmed + ": " + "Already shutdown. Ignoring fetch
complete");
+ } else {
+ successRssPartitionSet.add(partitionId);
+ MapHost mapHost = runningRssPartitionMap.remove(partitionId);
+ if (mapHost != null) {
+ for (int i = 0; i < mapHost.getAndClearKnownMaps().size(); i++) {
+ remainingMaps.decrementAndGet();
+ }
+ }
+ doBookKeepingForFetcherComplete();
+ LOG.info("FetchFutureCallback onSuccess, result:{}, success
partitionId:{}, successRssPartitionSet:{}, "
+ + "remainingMaps now value:{}", result,
rssFetcherOrderedGrouped.getPartitionId(),
+ successRssPartitionSet, remainingMaps.get());
+ }
+ }
+
+ @Override
+ public void onFailure(Throwable t) {
+ LOG.error("Failed to fetch.", t);
+ rssFetcherOrderedGrouped.shutDown();
+ if (isShutdown.get()) {
+ LOG.info(srcNameTrimmed + ": " + "Already shutdown. Ignoring fetch
complete");
+ } else {
+ LOG.error(srcNameTrimmed + ": " + "Fetcher failed with error", t);
+ exceptionReporter.reportException(t);
+ doBookKeepingForFetcherComplete();
+ }
+ }
+ }
+}
+
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java
new file mode 100644
index 00000000..e64645be
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java
@@ -0,0 +1,912 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.IdUtils;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezExecutors;
+import org.apache.tez.common.TezSharedExecutor;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.runtime.api.ExecutionContext;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.impl.ExecutionContextImpl;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.CompositeInputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class RssShuffleSchedulerTest {
+
+ private TezExecutors sharedExecutor;
+
  // Creates a fresh shared executor before each test case.
  @BeforeEach
  public void setup() {
    sharedExecutor = new TezSharedExecutor(new Configuration());
  }
+
+ @BeforeEach
+ public void cleanup() {
+ sharedExecutor.shutdownNow();
+ }
+
  @Test
  /**
   * Scenario
   *  - reducer has not progressed enough
   *  - reducer becomes unhealthy after some failures
   *  - no of attempts failing exceeds maxFailedUniqueFetches (5)
   * Expected result
   *  - fail the reducer
   */
  @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
  public void testReducerHealth1() throws IOException {
    // Static mocks: pin the application attempt id and the shuffle-provider
    // metadata (4 RSS partitions) that the scheduler reads at construction.
    try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
      ApplicationId appId = ApplicationId.newInstance(9999, 72);
      ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1);
      idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
      try (MockedStatic<ShuffleUtils> shuffleUtils = Mockito.mockStatic(ShuffleUtils.class)) {
        shuffleUtils.when(() -> ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);

        // Run once with defaults, once with a per-host failure threshold so
        // high it can never be reached.
        Configuration conf = new TezConfiguration();
        testReducerHealth1(conf);
        conf.setInt(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_FAILURES_PER_HOST, 4000);
        testReducerHealth1(conf);
      }
    }
  }
+
  // Drives the scheduler through 320 known outputs, 100 successes and 100
  // failures, then checks whether the reducer is reported unhealthy depending
  // on the configured per-host failure threshold.
  public void testReducerHealth1(Configuration conf) throws IOException {
    long startTime = System.currentTimeMillis() - 500000;
    Shuffle shuffle = mock(Shuffle.class);
    final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 320, shuffle, conf);

    int totalProducerNodes = 20;

    //Generate 320 events
    for (int i = 0; i < 320; i++) {
      CompositeInputAttemptIdentifier inputAttemptIdentifier =
          new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1);
      scheduler.addKnownMapOutput("host" + (i % totalProducerNodes),
          10000, i, inputAttemptIdentifier);
    }

    //100 succeeds
    for (int i = 0; i < 100; i++) {
      InputAttemptIdentifier inputAttemptIdentifier =
          new InputAttemptIdentifier(i, 0, "attempt_");
      MapOutput mapOutput = MapOutput
          .createMemoryMapOutput(inputAttemptIdentifier,
              mock(FetchedInputAllocatorOrderedGrouped.class), 100, false);
      scheduler.copySucceeded(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
          10000, i, 1), 100, 200, startTime + (i * 100), mapOutput, false);
    }

    //99 fails
    for (int i = 100; i < 199; i++) {
      InputAttemptIdentifier inputAttemptIdentifier =
          new InputAttemptIdentifier(i, 0, "attempt_");
      scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
          10000, i, 1), false, true, false);
    }

    InputAttemptIdentifier inputAttemptIdentifier =
        new InputAttemptIdentifier(200, 0, "attempt_");

    //Should fail here and report exception as reducer is not healthy
    scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (200 % totalProducerNodes),
        10000, 200, 1), false, true, false);

    int minFailurePerHost = conf.getInt(
        TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_FAILURES_PER_HOST,
        TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_MIN_FAILURES_PER_HOST_DEFAULT);

    if (minFailurePerHost <= 4) {
      //As per test threshold. Should fail & retrigger shuffle
      // NOTE(review): atLeast(0) is a vacuous verification — it can never fail.
      // If the intent is "must report", this should be atLeast(1); confirm
      // against RSS scheduler semantics.
      verify(shuffle, atLeast(0)).reportException(any(Throwable.class));
    } else if (minFailurePerHost > 100) {
      //host failure is so high that this would not retrigger shuffle re-execution
      verify(shuffle, atLeast(1)).reportException(any(Throwable.class));
    }
  }
+
  @Test
  /**
   * Scenario
   *  - reducer has progressed enough
   *  - failures start happening after that
   *  - no of attempts failing exceeds maxFailedUniqueFetches (5)
   *  - Has not stalled
   * Expected result
   *  - Since reducer is not stalled, it should continue without error
   *
   * When reducer stalls, wait until enough retries are done and throw exception
   */
  @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
  public void testReducerHealth2() throws IOException, InterruptedException {
    try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
      ApplicationId appId = ApplicationId.newInstance(9999, 72);
      ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1);
      idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
      try (MockedStatic<ShuffleUtils> shuffleUtils = Mockito.mockStatic(ShuffleUtils.class)) {
        shuffleUtils.when(() -> ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);

        long startTime = System.currentTimeMillis() - 500000;
        Shuffle shuffle = mock(Shuffle.class);
        final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 320, shuffle);

        int totalProducerNodes = 20;

        //Generate 0-200 events
        for (int i = 0; i < 200; i++) {
          CompositeInputAttemptIdentifier inputAttemptIdentifier =
              new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1);
          scheduler.addKnownMapOutput("host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier);
        }
        assertEquals(320, scheduler.remainingMaps.get());

        //Generate 200-320 events with empty partitions
        for (int i = 200; i < 320; i++) {
          InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
          scheduler.copySucceeded(inputAttemptIdentifier, null, 0, 0, 0, null, true);
        }
        //120 are successful. so remaining is 200
        assertEquals(200, scheduler.remainingMaps.get());

        //200 pending to be downloaded. Download 190.
        for (int i = 0; i < 190; i++) {
          InputAttemptIdentifier inputAttemptIdentifier =
              new InputAttemptIdentifier(i, 0, "attempt_");
          MapOutput mapOutput = MapOutput
              .createMemoryMapOutput(inputAttemptIdentifier, mock(FetchedInputAllocatorOrderedGrouped.class),
                  100, false);
          scheduler.copySucceeded(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), 100, 200, startTime + (i * 100), mapOutput, false);
        }

        assertEquals(10, scheduler.remainingMaps.get());

        //10 fails
        for (int i = 190; i < 200; i++) {
          InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
        }

        //Shuffle has not stalled. so no issues.
        verify(scheduler.reporter, times(0)).reportException(any(Throwable.class));

        //stall shuffle
        scheduler.lastProgressTime = System.currentTimeMillis() - 250000;

        InputAttemptIdentifier inputAttemptIdentifier =
            new InputAttemptIdentifier(190, 0, "attempt_");
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (190 % totalProducerNodes),
            10000, 190, 1), false, true, false);

        //Even when it is stalled, need (320 - 300 = 20) * 3 = 60 failures
        verify(scheduler.reporter, times(0)).reportException(any(Throwable.class));

        assertEquals(11, scheduler.failedShufflesSinceLastCompletion);

        //fail to download 50 more times across attempts
        for (int i = 190; i < 200; i++) {
          inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
        }

        assertEquals(61, scheduler.failedShufflesSinceLastCompletion);
        assertEquals(10, scheduler.remainingMaps.get());

        verify(shuffle, atLeast(0)).reportException(any(Throwable.class));

        //fail another 30
        for (int i = 110; i < 120; i++) {
          inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), false, true, false);
        }

        // Should fail now due to fetcherHealthy. (stall has already happened and
        // these are the only pending tasks)
        verify(shuffle, atLeast(1)).reportException(any(Throwable.class));
      }
    }
  }
+
  @Test
  /**
   * Scenario
   *  - reducer has progressed enough
   *  - failures start happening after that in last fetch
   *  - no of attempts failing does not exceed maxFailedUniqueFetches (5)
   *  - Stalled
   * Expected result
   *  - Since reducer is stalled and if failures haven't happened across nodes,
   *    it should be fine to proceed. AM would restart source task eventually.
   */
  @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
  public void testReducerHealth3() throws IOException {
    try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
      ApplicationId appId = ApplicationId.newInstance(9999, 72);
      ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1);
      idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);

      try (MockedStatic<ShuffleUtils> shuffleUtils = Mockito.mockStatic(ShuffleUtils.class)) {
        shuffleUtils.when(() -> ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);

        long startTime = System.currentTimeMillis() - 500000;
        Shuffle shuffle = mock(Shuffle.class);
        final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 320, shuffle);

        int totalProducerNodes = 20;

        //Generate 320 events
        for (int i = 0; i < 320; i++) {
          CompositeInputAttemptIdentifier inputAttemptIdentifier =
              new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1);
          scheduler.addKnownMapOutput("host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier);
        }

        //319 succeeds
        for (int i = 0; i < 319; i++) {
          InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
          MapOutput mapOutput = MapOutput
              .createMemoryMapOutput(inputAttemptIdentifier, mock(FetchedInputAllocatorOrderedGrouped.class),
                  100, false);
          scheduler.copySucceeded(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), 100, 200, startTime + (i * 100), mapOutput, false);
        }

        //1 fails (last fetch)
        InputAttemptIdentifier inputAttemptIdentifier =
            new InputAttemptIdentifier(319, 0, "attempt_");
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 319, 1), false, true, false);

        //stall the shuffle
        scheduler.lastProgressTime = System.currentTimeMillis() - 1000000;

        assertEquals(scheduler.remainingMaps.get(), 1);

        //Retry for 3 more times
        // NOTE(review): the last two retries use partition id 310, not 319 —
        // looks like a copy/paste slip carried over from upstream Tez; confirm.
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 319, 1), false, true, false);
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 310, 1), false, true, false);
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 310, 1), false, true, false);

        // NOTE(review): the comment below says "Throw error" but the
        // verification asserts reportException was NEVER called — the
        // verification matches the javadoc scenario (single-host failures do
        // not kill the consumer); the inline comment appears stale.
        // failedShufflesSinceLastCompletion has crossed the limits. Throw error
        verify(shuffle, times(0)).reportException(any(Throwable.class));
      }
    }
  }
+
  @Test
  /**
   * Scenario
   *  - reducer has progressed enough
   *  - failures have happened randomly in nodes, but tasks are completed
   *  - failures start happening after that in last fetch
   *  - no of attempts failing does not exceed maxFailedUniqueFetches (5)
   *  - Stalled
   * Expected result
   *  - reducer is stalled. But since errors are not seen across multiple
   *    nodes, it is left to the AM to restart producer. Do not kill consumer.
   */
  @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
  public void testReducerHealth4() throws IOException {
    try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
      ApplicationId appId = ApplicationId.newInstance(9999, 72);
      ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1);
      idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
      try (MockedStatic<ShuffleUtils> shuffleUtils = Mockito.mockStatic(ShuffleUtils.class)) {
        shuffleUtils.when(() -> ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);

        long startTime = System.currentTimeMillis() - 500000;
        Shuffle shuffle = mock(Shuffle.class);
        final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 320, shuffle);

        int totalProducerNodes = 20;

        //Generate 320 events
        for (int i = 0; i < 320; i++) {
          CompositeInputAttemptIdentifier inputAttemptIdentifier =
              new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1);
          scheduler.addKnownMapOutput("host" + (i % totalProducerNodes),
              10000, i, inputAttemptIdentifier);
        }

        //Tasks fail in 20% of nodes 3 times, but are able to proceed further
        for (int i = 0; i < 64; i++) {
          InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");

          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes), 10000, i, 1),
              false, true, false);

          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes), 10000, i, 1),
              false, true, false);

          scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes), 10000, i, 1),
              false, true, false);

          MapOutput mapOutput = MapOutput
              .createMemoryMapOutput(inputAttemptIdentifier, mock(FetchedInputAllocatorOrderedGrouped.class),
                  100, false);
          scheduler.copySucceeded(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), 100, 200, startTime + (i * 100), mapOutput, false);
        }

        //319 succeeds
        for (int i = 64; i < 319; i++) {
          InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
          MapOutput mapOutput = MapOutput
              .createMemoryMapOutput(inputAttemptIdentifier, mock(FetchedInputAllocatorOrderedGrouped.class),
                  100, false);
          scheduler.copySucceeded(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), 100, 200, startTime + (i * 100), mapOutput, false);
        }

        //1 fails (last fetch)
        InputAttemptIdentifier inputAttemptIdentifier =
            new InputAttemptIdentifier(319, 0, "attempt_");
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 319, 1), false, true, false);

        //stall the shuffle (but within limits)
        scheduler.lastProgressTime = System.currentTimeMillis() - 100000;

        assertEquals(scheduler.remainingMaps.get(), 1);

        //Retry for 3 more times
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 319, 1), false, true, false);
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 319, 1), false, true, false);
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 319, 1), false, true, false);

        // failedShufflesSinceLastCompletion has crossed the limits. 20% of other nodes had failures as
        // well. However, it has failed only in one host. So this should proceed
        // until AM decides to restart the producer.
        verify(shuffle, times(0)).reportException(any(Throwable.class));

        //stall the shuffle further (now beyond the tolerated window)
        scheduler.lastProgressTime = System.currentTimeMillis() - 300000;
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (319 % totalProducerNodes),
            10000, 319, 1), false, true, false);
        verify(shuffle, times(1)).reportException(any(Throwable.class));
      }
    }
  }
+
  @Test
  /**
   * Scenario
   *  - Shuffle has progressed enough
   *  - Last event is yet to arrive
   *  - Failures start happening after Shuffle has progressed enough
   *  - no of attempts failing does not exceed maxFailedUniqueFetches (5)
   *  - Stalled
   * Expected result
   *  - Do not throw errors, as Shuffle is yet to receive inputs
   */
  @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
  public void testReducerHealth5() throws IOException {
    try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
      ApplicationId appId = ApplicationId.newInstance(9999, 72);
      ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1);
      idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
      try (MockedStatic<ShuffleUtils> shuffleUtils = Mockito.mockStatic(ShuffleUtils.class)) {
        shuffleUtils.when(() -> ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);

        long startTime = System.currentTimeMillis() - 500000;
        Shuffle shuffle = mock(Shuffle.class);
        final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 320, shuffle);

        int totalProducerNodes = 20;

        //Generate 319 events (last event has not arrived)
        for (int i = 0; i < 319; i++) {
          CompositeInputAttemptIdentifier inputAttemptIdentifier =
              new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1);
          scheduler.addKnownMapOutput("host" + (i % totalProducerNodes),
              10000, i, inputAttemptIdentifier);
        }

        //318 succeeds
        // NOTE(review): the comment says 318 but the loop completes 319
        // attempts (0..318 inclusive) — comment and loop bound disagree; the
        // loop bound matches the remainingMaps assertion below.
        for (int i = 0; i < 319; i++) {
          InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
          MapOutput mapOutput = MapOutput
              .createMemoryMapOutput(inputAttemptIdentifier, mock(FetchedInputAllocatorOrderedGrouped.class),
                  100, false);
          scheduler.copySucceeded(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
              10000, i, 1), 100, 200, startTime + (i * 100), mapOutput, false);
        }

        //1 fails (last fetch)
        InputAttemptIdentifier inputAttemptIdentifier =
            new InputAttemptIdentifier(318, 0, "attempt_");
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (318 % totalProducerNodes),
            10000, 318, 1), false, true, false);

        //stall the shuffle
        scheduler.lastProgressTime = System.currentTimeMillis() - 1000000;

        assertEquals(scheduler.remainingMaps.get(), 1);

        //Retry for 3 more times
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (318 % totalProducerNodes),
            10000, 318, 1), false, true, false);
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (318 % totalProducerNodes),
            10000, 318, 1), false, true, false);
        scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (318 % totalProducerNodes),
            10000, 318, 1), false, true, false);

        //Shuffle has not received the events completely. So do not bail out yet.
        verify(shuffle, times(0)).reportException(any(Throwable.class));
      }
    }
  }
+
+
  @Test
  /**
   * Scenario
   *  - Shuffle has NOT progressed enough
   *  - Failures start happening
   *  - no of attempts failing exceed maxFailedUniqueFetches (5)
   *  - Not stalled
   * Expected result
   *  - Bail out
   */
  @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
  public void testReducerHealth6() throws IOException {
    try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
      ApplicationId appId = ApplicationId.newInstance(9999, 72);
      ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1);
      idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
      try (MockedStatic<ShuffleUtils> shuffleUtils = Mockito.mockStatic(ShuffleUtils.class)) {
        shuffleUtils.when(() -> ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);

        // Exercise both settings of the failed-check-since-last-completion flag.
        Configuration conf = new TezConfiguration();
        conf.setBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FAILED_CHECK_SINCE_LAST_COMPLETION, true);
        testReducerHealth6(conf);

        conf.setBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FAILED_CHECK_SINCE_LAST_COMPLETION, false);
        testReducerHealth6(conf);
      }
    }

  }
+
  // With only 10 of 320 inputs fetched, repeatedly fails 5 attempts and checks
  // that the scheduler bails out only when failed-check-since-last-completion
  // is enabled.
  public void testReducerHealth6(Configuration conf) throws IOException {
    long startTime = System.currentTimeMillis() - 500000;
    Shuffle shuffle = mock(Shuffle.class);
    final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 320, shuffle, conf);

    int totalProducerNodes = 20;

    //Generate 320 events (last event has not arrived)
    for (int i = 0; i < 320; i++) {
      CompositeInputAttemptIdentifier inputAttemptIdentifier =
          new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1);
      scheduler.addKnownMapOutput("host" + (i % totalProducerNodes),
          10000, i, inputAttemptIdentifier);
    }

    //10 succeeds
    for (int i = 0; i < 10; i++) {
      InputAttemptIdentifier inputAttemptIdentifier =
          new InputAttemptIdentifier(i, 0, "attempt_");
      MapOutput mapOutput = MapOutput
          .createMemoryMapOutput(inputAttemptIdentifier, mock(FetchedInputAllocatorOrderedGrouped.class),
              100, false);
      scheduler.copySucceeded(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
          10000, i, 1), 100, 200, startTime + (i * 100), mapOutput, false);
    }

    //5 fetches fail once
    for (int i = 10; i < 15; i++) {
      InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
      scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
          10000, i, 1), false, true, false);
    }

    assertTrue(scheduler.failureCounts.size() >= 5);
    assertEquals(scheduler.remainingMaps.get(), 310);

    //Do not bail out (number of failures is just 5)
    verify(scheduler.reporter, times(0)).reportException(any(Throwable.class));

    //5 fetches fail repeatedly
    for (int i = 10; i < 15; i++) {
      InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
      scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
          10000, i, 1), false, true, false);
      scheduler.copyFailed(inputAttemptIdentifier, new MapHost("host" + (i % totalProducerNodes),
          10000, i, 1), false, true, false);
    }

    boolean checkFailedFetchSinceLastCompletion = conf.getBoolean(
        TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FAILED_CHECK_SINCE_LAST_COMPLETION,
        TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FAILED_CHECK_SINCE_LAST_COMPLETION_DEFAULT);
    if (checkFailedFetchSinceLastCompletion) {
      // Now bail out, as Shuffle has crossed the
      // failedShufflesSinceLastCompletion limits. (even
      // though reducerHealthy is)
      verify(shuffle, atLeast(1)).reportException(any(Throwable.class));
    } else {
      //Do not bail out yet.
      // NOTE(review): atLeast(0) is a vacuous verification; it documents
      // intent only and can never fail.
      verify(shuffle, atLeast(0)).reportException(any(Throwable.class));
    }

  }
+
  @Test
  /**
   * Scenario
   *  - reducer has not progressed enough
   *  - fetch fails > TEZ_RUNTIME_SHUFFLE_ACCEPTABLE_HOST_FETCH_FAILURE_FRACTION
   * Expected result
   *  - fail the reducer
   */
  @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
  public void testReducerHealth7() throws IOException {
    try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
      ApplicationId appId = ApplicationId.newInstance(9999, 72);
      ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1);
      idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
      try (MockedStatic<ShuffleUtils> shuffleUtils = Mockito.mockStatic(ShuffleUtils.class)) {
        shuffleUtils.when(() -> ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);

        long startTime = System.currentTimeMillis() - 500000;
        Shuffle shuffle = mock(Shuffle.class);
        final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 320, shuffle);

        int totalProducerNodes = 20;

        //Generate 320 events
        for (int i = 0; i < 320; i++) {
          CompositeInputAttemptIdentifier inputAttemptIdentifier =
              new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1);
          scheduler.addKnownMapOutput("host" + (i % totalProducerNodes), 10000, i,
              inputAttemptIdentifier);
        }

        //100 succeeds
        for (int i = 0; i < 100; i++) {
          InputAttemptIdentifier inputAttemptIdentifier = new InputAttemptIdentifier(i, 0, "attempt_");
          MapOutput mapOutput = MapOutput
              .createMemoryMapOutput(inputAttemptIdentifier, mock(FetchedInputAllocatorOrderedGrouped.class),
                  100, false);
          scheduler.copySucceeded(inputAttemptIdentifier,
              new MapHost("host" + (i % totalProducerNodes), 10000, i, 1),
              100, 200, startTime + (i * 100), mapOutput, false);
        }

        //99 fails (4 failed attempts each, driving host failure fraction up)
        for (int i = 100; i < 199; i++) {
          InputAttemptIdentifier inputAttemptIdentifier =
              new InputAttemptIdentifier(i, 0, "attempt_");
          scheduler.copyFailed(inputAttemptIdentifier,
              new MapHost("host" + (i % totalProducerNodes), 10000, i, 1),
              false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier,
              new MapHost("host" + (i % totalProducerNodes), 10000, i, 1),
              false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier,
              new MapHost("host" + (i % totalProducerNodes), 10000, i, 1),
              false, true, false);
          scheduler.copyFailed(inputAttemptIdentifier,
              new MapHost("host" + (i % totalProducerNodes), 10000, i, 1),
              false, true, false);
        }

        verify(shuffle, atLeast(1)).reportException(any(Throwable.class));
      }
    }
  }
+
+ private ShuffleSchedulerForTest createScheduler(long startTime, int
+ numInputs, Shuffle shuffle, Configuration conf) throws IOException {
+ InputContext inputContext = createTezInputContext();
+ MergeManager mergeManager = mock(MergeManager.class);
+
+ final ShuffleSchedulerForTest scheduler =
+ new ShuffleSchedulerForTest(inputContext, conf, numInputs, shuffle,
mergeManager,
+ mergeManager,startTime, null, false, 0, "srcName");
+ return scheduler;
+ }
+
  // Convenience overload: builds the scheduler with a default TezConfiguration.
  private ShuffleSchedulerForTest createScheduler(long startTime, int numInputs, Shuffle shuffle)
      throws IOException {
    return createScheduler(startTime, numInputs, shuffle, new TezConfiguration());
  }
+
  // Verifies that a host whose fetch failed is placed in the penalty box and is
  // not handed out by getHost() until the referee releases it.
  @Test
  @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
  public void testPenalty() throws IOException, InterruptedException {
    try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
      ApplicationId appId = ApplicationId.newInstance(9999, 72);
      ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1);
      idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
      try (MockedStatic<ShuffleUtils> shuffleUtils = Mockito.mockStatic(ShuffleUtils.class)) {
        shuffleUtils.when(() -> ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);

        long startTime = System.currentTimeMillis();
        Shuffle shuffle = mock(Shuffle.class);
        final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 1, shuffle);

        CompositeInputAttemptIdentifier inputAttemptIdentifier =
            new CompositeInputAttemptIdentifier(0, 0, "attempt_", 1);
        scheduler.addKnownMapOutput("host0", 10000, 0, inputAttemptIdentifier);

        assertTrue(scheduler.pendingHosts.size() == 1);
        assertTrue(scheduler.pendingHosts.iterator().next().getState() == MapHost.State.PENDING);
        MapHost mapHost = scheduler.pendingHosts.iterator().next();

        //Fails to pull from host0. host0 should be added to penalties
        scheduler.copyFailed(inputAttemptIdentifier, mapHost, false, true, false);

        //Should not get host, as it is added to penalty loop
        // NOTE(review): the compared string is host:port:partition
        // ("host0:10000:0") while the literal is "host0:10000", so this
        // assertFalse can never fail — likely carried over from upstream;
        // confirm the intended identity string.
        MapHost host = scheduler.getHost();
        assertFalse((host.getHost() + ":" + host.getPort() + ":"
            + host.getPartitionId()).equalsIgnoreCase("host0:10000"));

        //Referee thread would release it after INITIAL_PENALTY timeout
        // NOTE(review): after the penalty expires the original still asserts
        // assertFalse on the same vacuous comparison; an assertTrue on the
        // released host may have been intended — verify.
        Thread.sleep(ShuffleScheduler.INITIAL_PENALTY + 1000);
        host = scheduler.getHost();
        assertFalse((host.getHost() + ":" + host.getPort() + ":"
            + host.getPartitionId()).equalsIgnoreCase("host0:10000"));
      }
    }
  }
+
+ @Test
+ @Timeout(value = 20000, unit = TimeUnit.MILLISECONDS)
+ public void testProgressDuringGetHostWait() throws IOException,
InterruptedException {
+ try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
+ ApplicationId appId = ApplicationId.newInstance(9999, 72);
+ ApplicationAttemptId appAttemptId =
ApplicationAttemptId.newInstance(appId, 1);
+ idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
+ try (MockedStatic<ShuffleUtils> shuffleUtils =
Mockito.mockStatic(ShuffleUtils.class)) {
+ shuffleUtils.when(() ->
ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);
+
+ long startTime = System.currentTimeMillis();
+ Configuration conf = new TezConfiguration();
+ Shuffle shuffle = mock(Shuffle.class);
+ final ShuffleSchedulerForTest scheduler = createScheduler(startTime,
1, shuffle, conf);
+ Thread schedulerGetHostThread = new Thread(new Runnable() {
+ @Override
+ public void run() {
+ try {
+ scheduler.getHost();
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+ });
+ schedulerGetHostThread.start();
+ Thread.currentThread().sleep(1000 * 3 + 1000);
+ schedulerGetHostThread.interrupt();
+ verify(scheduler.inputContext, atLeast(3)).notifyProgress();
+ }
+ }
+ }
+
+ @Test
+ @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
+ public void testShutdownWithInterrupt() throws Exception {
+ try (MockedStatic<IdUtils> idUtils = Mockito.mockStatic(IdUtils.class)) {
+ ApplicationId appId = ApplicationId.newInstance(9999, 72);
+ ApplicationAttemptId appAttemptId =
ApplicationAttemptId.newInstance(appId, 1);
+ idUtils.when(IdUtils::getApplicationAttemptId).thenReturn(appAttemptId);
+ try (MockedStatic<ShuffleUtils> shuffleUtils =
Mockito.mockStatic(ShuffleUtils.class)) {
+ shuffleUtils.when(() ->
ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);
+
+ InputContext inputContext = createTezInputContext();
+ Configuration conf = new TezConfiguration();
+ int numInputs = 10;
+ Shuffle shuffle = mock(Shuffle.class);
+ MergeManager mergeManager = mock(MergeManager.class);
+
+ final ShuffleSchedulerForTest scheduler =
+ new ShuffleSchedulerForTest(inputContext, conf, numInputs,
shuffle, mergeManager,
+ mergeManager,
+ System.currentTimeMillis(), null, false, 0, "srcName");
+
+ ExecutorService executor = Executors.newFixedThreadPool(1);
+
+ Future<Void> executorFuture = executor.submit(new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ scheduler.start();
+ return null;
+ }
+ });
+
+ InputAttemptIdentifier[] identifiers = new
InputAttemptIdentifier[numInputs];
+
+ for (int i = 0; i < numInputs; i++) {
+ CompositeInputAttemptIdentifier inputAttemptIdentifier =
+ new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1);
+ scheduler.addKnownMapOutput("host" + i, 10000, 1,
inputAttemptIdentifier);
+ identifiers[i] = inputAttemptIdentifier;
+ }
+
+ MapHost[] mapHosts = new MapHost[numInputs];
+ int count = 0;
+ for (MapHost mh : scheduler.mapLocations.values()) {
+ mapHosts[count++] = mh;
+ }
+
+ // Copy succeeded for 1 less host
+ for (int i = 0; i < numInputs - 1; i++) {
+ MapOutput mapOutput = MapOutput
+ .createMemoryMapOutput(identifiers[i],
mock(FetchedInputAllocatorOrderedGrouped.class),
+ 100, false);
+ scheduler.copySucceeded(identifiers[i], mapHosts[i], 20, 25, 100,
mapOutput, false);
+ scheduler.freeHost(mapHosts[i]);
+ }
+
+ try {
+ // Close the scheduler on different thread to trigger interrupt
+ Thread thread = new Thread(new Runnable() {
+ @Override
+ public void run() {
+ scheduler.close();
+ }
+ });
+ thread.start();
+ thread.join();
+ } finally {
+ assertTrue(scheduler.hasFetcherExecutorStopped());
+ executor.shutdownNow();
+ }
+ }
+ }
+ }
+
+
+ private InputContext createTezInputContext() throws IOException {
+ ApplicationId applicationId = ApplicationId.newInstance(1, 1);
+ InputContext inputContext = mock(InputContext.class);
+ doReturn(applicationId).when(inputContext).getApplicationId();
+ doReturn("sourceVertex").when(inputContext).getSourceVertexName();
+ when(inputContext.getCounters()).thenReturn(new TezCounters());
+ ExecutionContext executionContext = new ExecutionContextImpl("localhost");
+ doReturn(executionContext).when(inputContext).getExecutionContext();
+ ByteBuffer shuffleBuffer = ByteBuffer.allocate(4).putInt(0, 4);
+
doReturn(shuffleBuffer).when(inputContext).getServiceProviderMetaData(anyString());
+ Token<JobTokenIdentifier>
+ sessionToken = new Token<JobTokenIdentifier>(new
JobTokenIdentifier(new Text("text")),
+ new JobTokenSecretManager());
+ ByteBuffer tokenBuffer = TezCommonUtils.serializeServiceData(sessionToken);
+
doReturn(tokenBuffer).when(inputContext).getServiceConsumerMetaData(anyString());
+ return inputContext;
+ }
+
+ private static class ShuffleSchedulerForTest extends RssShuffleScheduler {
+
+ private final AtomicInteger numFetchersCreated = new AtomicInteger(0);
+ private final boolean fetcherShouldWait;
+ private final ExceptionReporter reporter;
+ private final InputContext inputContext;
+
+ ShuffleSchedulerForTest(InputContext inputContext, Configuration conf, int
numberOfInputs,
+ Shuffle shuffle,
+ MergeManager mergeManager,
+ FetchedInputAllocatorOrderedGrouped allocator, long startTime,
+ CompressionCodec codec,
+ boolean ifileReadAhead, int ifileReadAheadLength,
+ String srcNameTrimmed) throws IOException {
+ this(inputContext, conf, numberOfInputs, shuffle, mergeManager,
allocator, startTime, codec,
+ ifileReadAhead, ifileReadAheadLength, srcNameTrimmed, false);
+ }
+
+ ShuffleSchedulerForTest(InputContext inputContext, Configuration conf,
+ int numberOfInputs,
+ Shuffle shuffle,
+ MergeManager mergeManager,
+ FetchedInputAllocatorOrderedGrouped allocator, long startTime,
+ CompressionCodec codec,
+ boolean ifileReadAhead, int ifileReadAheadLength,
+ String srcNameTrimmed, boolean fetcherShouldWait) throws
IOException {
+ super(inputContext, conf, numberOfInputs, shuffle, mergeManager,
allocator, startTime, codec,
+ ifileReadAhead, ifileReadAheadLength, srcNameTrimmed);
+ this.fetcherShouldWait = fetcherShouldWait;
+ this.reporter = shuffle;
+ this.inputContext = inputContext;
+ }
+
+ @Override
+ FetcherOrderedGrouped constructFetcherForHost(MapHost mapHost) {
+ numFetchersCreated.incrementAndGet();
+ FetcherOrderedGrouped mockFetcher = mock(FetcherOrderedGrouped.class);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ if (fetcherShouldWait) {
+ Thread.sleep(100000L);
+ }
+ return null;
+ }
+ }).when(mockFetcher).callInternal();
+ return mockFetcher;
+ }
+ }
+}
+