This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new fbc0b223 [#854][FOLLOWUP] feat(tez): Add RssShuffleManager to run and 
manage shuffle work (#947)
fbc0b223 is described below

commit fbc0b223eeb6a403631e3e9000ed739dde4d229c
Author: Qing <[email protected]>
AuthorDate: Wed Jun 21 15:55:36 2023 +0800

    [#854][FOLLOWUP] feat(tez): Add RssShuffleManager to run and manage 
shuffle work (#947)
    
    ### What changes were proposed in this pull request?
    
    Add RssShuffleManager to run and manage shuffle work
    
    ### Why are the changes needed?
    
    
    Fix: https://github.com/apache/incubator-uniffle/issues/854
    
    ### Does this PR introduce _any_ user-facing change?
    
    
    No.
    
    ### How was this patch tested?
    
    unit test
---
 .../common/shuffle/impl/RssShuffleManager.java     | 1238 ++++++++++++++++++++
 .../common/shuffle/impl/RssTezFetcherTask.java     |  205 ++++
 .../common/shuffle/impl/RssShuffleManagerTest.java |  341 ++++++
 3 files changed, 1784 insertions(+)

diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
new file mode 100644
index 00000000..1c53f8c0
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
@@ -0,0 +1,1238 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Objects;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalDirAllocator;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.RawLocalFileSystem;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.util.Time;
+import org.apache.tez.common.CallableWithNdc;
+import org.apache.tez.common.InputContextUtils;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.http.HttpConnectionParams;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.TaskFailureType;
+import org.apache.tez.runtime.api.events.InputReadErrorEvent;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.CompositeInputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+import org.apache.tez.runtime.library.common.shuffle.FetchResult;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator;
+import org.apache.tez.runtime.library.common.shuffle.Fetcher;
+import org.apache.tez.runtime.library.common.shuffle.Fetcher.FetcherBuilder;
+import org.apache.tez.runtime.library.common.shuffle.HostPort;
+import org.apache.tez.runtime.library.common.shuffle.InputHost;
+import 
org.apache.tez.runtime.library.common.shuffle.InputHost.PartitionToInputs;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import 
org.apache.tez.runtime.library.common.shuffle.ShuffleUtils.FetchStatsLogger;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+// This only knows how to deal with a single srcIndex for a given targetIndex.
+// In case the src task generates multiple outputs for the same target Index
+// (multiple src-indices), modifications will be required.
+public class RssShuffleManager extends ShuffleManager {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleManager.class);
+  private static final Logger LOG_FETCH = 
LoggerFactory.getLogger(LOG.getName() + ".fetch");
+  private static final FetchStatsLogger fetchStatsLogger = new 
FetchStatsLogger(LOG_FETCH, LOG);
+
+  private final InputContext inputContext;
+  private final int numInputs;
+
+  private final DecimalFormat mbpsFormat = new DecimalFormat("0.00");
+
+  private final FetchedInputAllocator inputManager;
+
+  @VisibleForTesting
+  final ListeningExecutorService fetcherExecutor;
+
+  /**
+   * Executor for ReportCallable.
+   */
+  private ExecutorService reporterExecutor;
+
+  /**
+   * Lock to sync failedEvents.
+   */
+  private final ReentrantLock reportLock = new ReentrantLock();
+
+  /**
+   * Condition to wake up the thread notifying when events fail.
+   */
+  private final Condition reportCondition = reportLock.newCondition();
+
+  /**
+   * Events reporting fetcher failed.
+   */
+  private final HashMap<InputReadErrorEvent, Integer> failedEvents = new 
HashMap<>();
+
+  private final ListeningExecutorService schedulerExecutor;
+  private final RssRunShuffleCallable rssSchedulerCallable;
+
+  private final BlockingQueue<FetchedInput> completedInputs;
+  private final AtomicBoolean inputReadyNotificationSent = new 
AtomicBoolean(false);
+  @VisibleForTesting
+  final BitSet completedInputSet;
+  private final ConcurrentMap<HostPort, InputHost> knownSrcHosts;
+  private final BlockingQueue<InputHost> pendingHosts;
+  private final Set<InputAttemptIdentifier> obsoletedInputs;
+  private Set<RssTezFetcherTask> rssRunningFetchers;
+
+  private final AtomicInteger numCompletedInputs = new AtomicInteger(0);
+  private final AtomicInteger numFetchedSpills = new AtomicInteger(0);
+
+  private final long startTime;
+  private long lastProgressTime;
+  private long totalBytesShuffledTillNow;
+
+  // Required to be held when manipulating pendingHosts
+  private final ReentrantLock lock = new ReentrantLock();
+  private final Condition wakeLoop = lock.newCondition();
+
+  private final int numFetchers;
+  private final boolean asyncHttp;
+
+  // Parameters required by Fetchers
+  private final CompressionCodec codec;
+  private final Configuration conf;
+  private final boolean localDiskFetchEnabled;
+  private final boolean sharedFetchEnabled;
+  private final boolean verifyDiskChecksum;
+  private final boolean compositeFetch;
+
+  private final int ifileBufferSize;
+  private final boolean ifileReadAhead;
+  private final int ifileReadAheadLength;
+
+  /**
+   * Holds the time to wait for failures to batch them and send less events.
+   */
+  private final int maxTimeToWaitForReportMillis;
+
+  private final String srcNameTrimmed;
+
+  private final int maxTaskOutputAtOnce;
+
+  private final AtomicBoolean isShutdown = new AtomicBoolean(false);
+
+  private final TezCounter shuffledInputsCounter;
+  private final TezCounter failedShufflesCounter;
+  private final TezCounter bytesShuffledCounter;
+  private final TezCounter decompressedDataSizeCounter;
+  private final TezCounter bytesShuffledToDiskCounter;
+  private final TezCounter bytesShuffledToMemCounter;
+  private final TezCounter bytesShuffledDirectDiskCounter;
+
+  private volatile Throwable shuffleError;
+  private final HttpConnectionParams httpConnectionParams;
+
+  private final LocalDirAllocator localDirAllocator;
+  private final RawLocalFileSystem localFs;
+  private final Path[] localDisks;
+  private final String localhostName;
+  private final int shufflePort;
+
+  private final TezCounter shufflePhaseTime;
+  private final TezCounter firstEventReceived;
+  private final TezCounter lastEventReceived;
+
+  //To track shuffleInfo events when finalMerge is disabled OR pipelined 
shuffle is enabled in source.
+  @VisibleForTesting
+  final Map<Integer, ShuffleEventInfo> shuffleInfoEventsMap;
+
+  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+  private final Set<Integer> successRssPartitionSet = new HashSet<>();
+  private final Set<Integer> runningRssPartitionMap = new HashSet<>();
+  private final Set<Integer> allRssPartition = Sets.newConcurrentHashSet();
+  private final BlockingQueue<Integer> pendingPartition = new 
LinkedBlockingQueue<>();
+  Map<Integer, List<InputAttemptIdentifier>> partitionToInput = new 
HashMap<>();
+  private final Map<Integer, Roaring64NavigableMap> rssAllBlockIdBitmapMap = 
new ConcurrentHashMap<>();
+  private final Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap 
= new ConcurrentHashMap<>();
+  private final AtomicInteger numNoDataInput = new AtomicInteger(0);
+  private final AtomicInteger numWithDataInput = new AtomicInteger(0);
+
+  public RssShuffleManager(InputContext inputContext, Configuration conf, int 
numInputs,
+          int bufferSize, boolean ifileReadAheadEnabled, int 
ifileReadAheadLength,
+          CompressionCodec codec, FetchedInputAllocator inputAllocator) throws 
IOException {
+    super(inputContext, conf, numInputs, bufferSize, ifileReadAheadEnabled, 
ifileReadAheadLength, codec,
+        inputAllocator);
+    this.inputContext = inputContext;
+    this.conf = conf;
+    this.numInputs = numInputs;
+
+    this.shuffledInputsCounter = 
inputContext.getCounters().findCounter(TaskCounter.NUM_SHUFFLED_INPUTS);
+    this.failedShufflesCounter = 
inputContext.getCounters().findCounter(TaskCounter.NUM_FAILED_SHUFFLE_INPUTS);
+    this.bytesShuffledCounter = 
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES);
+    this.decompressedDataSizeCounter = 
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_DECOMPRESSED);
+    this.bytesShuffledToDiskCounter = 
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_TO_DISK);
+    this.bytesShuffledToMemCounter = 
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_TO_MEM);
+    this.bytesShuffledDirectDiskCounter = 
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_BYTES_DISK_DIRECT);
+
+    this.ifileBufferSize = bufferSize;
+    this.ifileReadAhead = ifileReadAheadEnabled;
+    this.ifileReadAheadLength = ifileReadAheadLength;
+    this.codec = codec;
+    this.inputManager = inputAllocator;
+    this.localDiskFetchEnabled = 
conf.getBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_OPTIMIZE_LOCAL_FETCH,
+        TezRuntimeConfiguration.TEZ_RUNTIME_OPTIMIZE_LOCAL_FETCH_DEFAULT);
+    this.sharedFetchEnabled = 
conf.getBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_OPTIMIZE_SHARED_FETCH,
+        TezRuntimeConfiguration.TEZ_RUNTIME_OPTIMIZE_SHARED_FETCH_DEFAULT);
+    this.verifyDiskChecksum = conf.getBoolean(
+        TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_VERIFY_DISK_CHECKSUM,
+        
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_VERIFY_DISK_CHECKSUM_DEFAULT);
+    this.maxTimeToWaitForReportMillis = 1;
+
+    this.shufflePhaseTime = 
inputContext.getCounters().findCounter(TaskCounter.SHUFFLE_PHASE_TIME);
+    this.firstEventReceived = 
inputContext.getCounters().findCounter(TaskCounter.FIRST_EVENT_RECEIVED);
+    this.lastEventReceived = 
inputContext.getCounters().findCounter(TaskCounter.LAST_EVENT_RECEIVED);
+    this.compositeFetch = ShuffleUtils.isTezShuffleHandler(conf);
+
+    this.srcNameTrimmed = 
TezUtilsInternal.cleanVertexName(inputContext.getSourceVertexName());
+
+    completedInputSet = new BitSet(numInputs);
+    /**
+     * In case of pipelined shuffle, it is possible to get multiple 
FetchedInput per attempt.
+     * We do not know upfront the number of spills from source.
+     */
+    completedInputs = new LinkedBlockingDeque<>();
+    knownSrcHosts = new ConcurrentHashMap<>();
+    pendingHosts = new LinkedBlockingQueue<>();
+    obsoletedInputs = Collections.newSetFromMap(new ConcurrentHashMap<>());
+    rssRunningFetchers = Collections.newSetFromMap(new ConcurrentHashMap<>());
+
+    int maxConfiguredFetchers =
+        conf.getInt(
+            TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES,
+            
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES_DEFAULT);
+
+    this.numFetchers = Math.min(maxConfiguredFetchers, numInputs);
+
+    final ExecutorService fetcherRawExecutor;
+    if 
(conf.getBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCHER_USE_SHARED_POOL,
+        
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCHER_USE_SHARED_POOL_DEFAULT)) {
+      fetcherRawExecutor = 
inputContext.createTezFrameworkExecutorService(numFetchers,
+          "Fetcher_B {" + srcNameTrimmed + "} #%d");
+    } else {
+      fetcherRawExecutor = Executors.newFixedThreadPool(numFetchers, new 
ThreadFactoryBuilder()
+          .setDaemon(true).setNameFormat("Fetcher_B {" + srcNameTrimmed + "} 
#%d").build());
+    }
+    this.fetcherExecutor = 
MoreExecutors.listeningDecorator(fetcherRawExecutor);
+
+    ExecutorService schedulerRawExecutor = Executors.newFixedThreadPool(1, new 
ThreadFactoryBuilder()
+        .setDaemon(true).setNameFormat("ShuffleRunner {" + srcNameTrimmed + 
"}").build());
+    this.schedulerExecutor = 
MoreExecutors.listeningDecorator(schedulerRawExecutor);
+    this.rssSchedulerCallable = new RssRunShuffleCallable(conf);
+
+    this.startTime = System.currentTimeMillis();
+    this.lastProgressTime = startTime;
+
+    this.asyncHttp = 
conf.getBoolean(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_USE_ASYNC_HTTP, 
false);
+    httpConnectionParams = ShuffleUtils.getHttpConnectionParams(conf);
+
+    this.localFs = (RawLocalFileSystem) FileSystem.getLocal(conf).getRaw();
+
+    this.localDirAllocator = new LocalDirAllocator(
+        TezRuntimeFrameworkConfigs.LOCAL_DIRS);
+
+    this.localDisks = Iterables.toArray(
+        localDirAllocator.getAllLocalPathsToRead(".", conf), Path.class);
+    this.localhostName = inputContext.getExecutionContext().getHostName();
+
+    String auxiliaryService = 
conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
+        TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);
+    final ByteBuffer shuffleMetaData =
+        inputContext.getServiceProviderMetaData(auxiliaryService);
+    this.shufflePort = 
ShuffleUtils.deserializeShuffleProviderMetaData(shuffleMetaData);
+
+    /**
+     * Setting to very high val can lead to Http 400 error. Cap it to 75; 
every attempt id would
+     * be approximately 48 bytes; 48 * 75 = 3600 which should give some room 
for other info in URL.
+     */
+    this.maxTaskOutputAtOnce = Math.max(1, Math.min(75, conf.getInt(
+        
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_MAX_TASK_OUTPUT_AT_ONCE,
+        
TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCH_MAX_TASK_OUTPUT_AT_ONCE_DEFAULT)));
+
+    if (null != this.localDisks) {
+      Arrays.sort(this.localDisks);
+    }
+
+    shuffleInfoEventsMap = new ConcurrentHashMap<>();
+
+    LOG.info(srcNameTrimmed + ": numInputs=" + numInputs + ", 
compressionCodec="
+        + (codec == null ? "NoCompressionCodec" : codec.getClass().getName()) 
+ ", numFetchers="
+        + numFetchers + ", ifileBufferSize=" + ifileBufferSize + ", 
ifileReadAheadEnabled="
+        + ifileReadAhead + ", ifileReadAheadLength=" + ifileReadAheadLength + 
", "
+        + "localDiskFetchEnabled=" + localDiskFetchEnabled + ", "
+        + "sharedFetchEnabled=" + sharedFetchEnabled + ", "
+        + httpConnectionParams.toString() + ", maxTaskOutputAtOnce=" + 
maxTaskOutputAtOnce);
+  }
+
+  @Override
+  public void run() throws IOException {
+    int shuffleId = InputContextUtils.computeShuffleId(this.inputContext);
+    TezTaskAttemptID tezTaskAttemptId = 
InputContextUtils.getTezTaskAttemptID(this.inputContext);
+    this.partitionToServers = UmbilicalUtils.requestShuffleServer(
+          this.inputContext.getApplicationId(), this.conf, tezTaskAttemptId, 
shuffleId);
+
+
+    Preconditions.checkState(inputManager != null, "InputManager must be 
configured");
+    if (maxTimeToWaitForReportMillis > 0) {
+      reporterExecutor = Executors.newSingleThreadExecutor(
+          new ThreadFactoryBuilder().setDaemon(true)
+              .setNameFormat("ShuffleRunner {" + srcNameTrimmed + "}")
+              .build());
+      Future reporterFuture = reporterExecutor.submit(new ReporterCallable());
+    }
+
+    ListenableFuture<Void> runShuffleFuture = 
schedulerExecutor.submit(rssSchedulerCallable);
+    Futures.addCallback(runShuffleFuture, new SchedulerFutureCallback(), 
MoreExecutors.directExecutor());
+    // Shutdown this executor once this task, and the callback complete.
+    schedulerExecutor.shutdown();
+  }
+
+  private class ReporterCallable extends CallableWithNdc<Void> {
+    /**
+     * Measures if the batching interval has ended.
+     */
+
+    ReporterCallable() {
+    }
+
+    @Override
+    protected Void callInternal() throws Exception {
+      long nextReport = 0;
+      while (!isShutdown.get()) {
+        try {
+          reportLock.lock();
+          while (failedEvents.isEmpty()) {
+            boolean signaled = 
reportCondition.await(maxTimeToWaitForReportMillis,
+                TimeUnit.MILLISECONDS);
+          }
+
+          long currentTime = Time.monotonicNow();;
+          if (currentTime > nextReport) {
+            if (failedEvents.size() > 0) {
+              List<Event> failedEventsToSend = Lists.newArrayListWithCapacity(
+                  failedEvents.size());
+              for (InputReadErrorEvent key : failedEvents.keySet()) {
+                failedEventsToSend.add(InputReadErrorEvent
+                    .create(key.getDiagnostics(), key.getIndex(),
+                        key.getVersion()));
+              }
+              inputContext.sendEvents(failedEventsToSend);
+              failedEvents.clear();
+              nextReport = currentTime + maxTimeToWaitForReportMillis;
+            }
+          }
+        } finally {
+          reportLock.unlock();
+        }
+      }
+      return null;
+    }
+  }
+
+
+  private boolean isAllInputFetched() {
+    LOG.info("Check isAllInputFetched, numNoDataInput:{}, 
numWithDataInput:{},numInputs:{},  " 
+            + "successRssPartitionSet:{},  allRssPartition:{}.",
+        numNoDataInput, numWithDataInput, numInputs, successRssPartitionSet, 
allRssPartition);
+    return (numNoDataInput.get() + numWithDataInput.get() >= numInputs)
+        && (successRssPartitionSet.size() >= allRssPartition.size());
+  }
+
+  private boolean isAllInputAdded() {
+    LOG.info("Check isAllInputAdded, numNoDataInput:{}, 
numWithDataInput:{},numInputs:{},  "
+            + "successRssPartitionSet:{}, allRssPartition:{}.",
+        numNoDataInput, numWithDataInput, numInputs, successRssPartitionSet, 
allRssPartition);
+    return numNoDataInput.get() + numWithDataInput.get() >= numInputs;
+  }
+
+  private class RssRunShuffleCallable extends CallableWithNdc<Void> {
+
+    private final Configuration conf;
+
+    RssRunShuffleCallable(Configuration conf) {
+      this.conf = conf;
+    }
+
+    @Override
+    protected Void callInternal() throws Exception {
+      while (!isShutdown.get() && !isAllInputFetched()) {
+        lock.lock();
+        try {
+          LOG.info("numFetchers:{}, shuffleInfoEventsMap.size:{}, 
numInputs:{}.",
+              numFetchers, shuffleInfoEventsMap.size(), numInputs);
+          while (((rssRunningFetchers.size() >= numFetchers || 
pendingPartition.isEmpty()) && !isAllInputFetched())
+                || !isAllInputAdded()) {
+            LOG.info("isAllInputAdded:{}, rssRunningFetchers:{}, 
numFetchers:{}, pendingPartition:{}, "
+                    + "successRssPartitionSet:{}, allRssPartition:{} ", 
isAllInputAdded(), rssRunningFetchers,
+                numFetchers, pendingPartition, successRssPartitionSet, 
allRssPartition);
+
+            inputContext.notifyProgress();
+            boolean isSignal =  wakeLoop.await(1000, TimeUnit.MILLISECONDS);
+            if (isSignal) {
+              LOG.info("wakeLoop is signal");
+            }
+            if (isShutdown.get()) {
+              LOG.info("is shut down and break");
+              break;
+            }
+          }
+          LOG.info("run out of while, is all inputadded:{}, fetched:{}", 
isAllInputAdded(), isAllInputFetched());
+        } finally {
+          lock.unlock();
+        }
+
+        if (shuffleError != null) {
+          LOG.warn("Shuffle error.", shuffleError);
+          // InputContext has already been informed of a fatal error. Relying 
on
+          // tez to kill the task.
+          break;
+        }
+
+        if (LOG.isDebugEnabled()) {
+          LOG.debug(srcNameTrimmed + ": " + "NumCompletedInputs: " + 
numCompletedInputs);
+        }
+
+        if (!isAllInputFetched() && !isShutdown.get()) {
+          lock.lock();
+          try {
+            LOG.info("numFetchers:{},runningFetchers.size():{}.", numFetchers, 
rssRunningFetchers.size());
+            int maxFetchersToRun = numFetchers - rssRunningFetchers.size();
+            int count = 0;
+            LOG.info("pendingPartition:{}", pendingPartition.peek());
+            while (pendingPartition.peek() != null && !isShutdown.get()) {
+              Integer partition = null;
+              try {
+                partition = pendingPartition.take();
+              } catch (InterruptedException e) {
+                if (isShutdown.get()) {
+                  LOG.info(srcNameTrimmed + ": " + "Interrupted and 
hasBeenShutdown, Breaking out of ShuffleScheduler");
+                  Thread.currentThread().interrupt();
+                  break;
+                } else {
+                  throw e;
+                }
+              }
+
+              if (LOG.isDebugEnabled()) {
+                LOG.debug(srcNameTrimmed + ": " + "Processing pending 
partition: " + partition);
+              }
+
+              if (!isShutdown.get() && 
(!successRssPartitionSet.contains(partition)
+                  && !runningRssPartitionMap.contains(partition))) {
+                runningRssPartitionMap.add(partition);
+                LOG.info("generate RssTezFetcherTask, partition:{}, 
rssWoker:{}, all woker:{}",
+                    partition, partitionToServers.get(partition), 
partitionToServers);
+
+                RssTezFetcherTask fetcher = new 
RssTezFetcherTask(RssShuffleManager.this, inputContext,
+                    conf, inputManager, partition, 
partitionToInput.get(partition),
+                    new 
HashSet<ShuffleServerInfo>(partitionToServers.get(partition)),
+                    rssAllBlockIdBitmapMap, rssSuccessBlockIdBitmapMap, 
numInputs, partitionToServers.size());
+                rssRunningFetchers.add(fetcher);
+                if (isShutdown.get()) {
+                  LOG.info(srcNameTrimmed + ": " + "hasBeenShutdown,"
+                      + "Breaking out of ShuffleScheduler Loop");
+                  break;
+                }
+                ListenableFuture<FetchResult> future = 
fetcherExecutor.submit(fetcher);    // add fetcher task
+                Futures.addCallback(future, new FetchFutureCallback(fetcher), 
MoreExecutors.directExecutor());
+                if (++count >= maxFetchersToRun) {
+                  break;
+                }
+              } else {
+                if (LOG.isDebugEnabled()) {
+                  LOG.debug(srcNameTrimmed + ": " + "Skipping partition: " + 
partition + " since is shutdown");
+                }
+              }
+            }
+          } finally {
+            lock.unlock();
+          }
+        }
+      }
+      LOG.info("RssShuffleManager numInputs:{}", numInputs);
+      shufflePhaseTime.setValue(System.currentTimeMillis() - startTime);
+      LOG.info(srcNameTrimmed + ": " + "Shutting down FetchScheduler, Was 
Interrupted: "
+              + Thread.currentThread().isInterrupted());
+      if (!fetcherExecutor.isShutdown()) {
+        fetcherExecutor.shutdownNow();
+      }
+      return null;
+    }
+  }
+
+
+  private boolean 
validateInputAttemptForPipelinedShuffle(InputAttemptIdentifier input) {
+    //For pipelined shuffle.
+    //TEZ-2132 for error handling. As of now, fail fast if there is a 
different attempt
+    if (input.canRetrieveInputInChunks()) {
+      ShuffleEventInfo eventInfo = 
shuffleInfoEventsMap.get(input.getInputIdentifier());
+      if (eventInfo != null && input.getAttemptNumber() != 
eventInfo.attemptNum) {
+        if (eventInfo.scheduledForDownload || 
!eventInfo.eventsProcessed.isEmpty()) {
+          IOException exception = new IOException("Previous event already got 
scheduled for " 
+              + input + ". Previous attempt's data could have been already 
merged "
+              + "to memory/disk outputs.  Killing (self) this task early."
+              + " currentAttemptNum=" + eventInfo.attemptNum
+              + ", eventsProcessed=" + eventInfo.eventsProcessed
+              + ", scheduledForDownload=" + eventInfo.scheduledForDownload
+              + ", newAttemptNum=" + input.getAttemptNumber());
+          String message = "Killing self as previous attempt data could have 
been consumed";
+          killSelf(exception, message);
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  @Override
+  void killSelf(Exception exception, String message) {
+    LOG.error(message, exception);
+    this.inputContext.killSelf(exception, message);
+  }
+
+  @VisibleForTesting
+  @Override
+  Fetcher constructFetcherForHost(InputHost inputHost, Configuration conf) {
+    Path lockDisk = null;
+
+    if (sharedFetchEnabled) {
+      // pick a single lock disk from the edge name's hashcode + host hashcode
+      final int h = Math.abs(Objects.hashCode(this.srcNameTrimmed, 
inputHost.getHost()));
+      lockDisk = new Path(this.localDisks[h % this.localDisks.length], 
"locks");
+    }
+
+    FetcherBuilder fetcherBuilder = new FetcherBuilder(RssShuffleManager.this,
+        httpConnectionParams, inputManager, inputContext.getApplicationId(), 
inputContext.getDagIdentifier(),
+        null, srcNameTrimmed, conf, localFs, localDirAllocator,
+        lockDisk, localDiskFetchEnabled, sharedFetchEnabled,
+        localhostName, shufflePort, asyncHttp, verifyDiskChecksum, 
compositeFetch);
+
+    if (codec != null) {
+      fetcherBuilder.setCompressionParameters(codec);
+    }
+    fetcherBuilder.setIFileParams(ifileReadAhead, ifileReadAheadLength);
+
+    // Remove obsolete inputs from the list being given to the fetcher. Also
+    // remove from the obsolete list.
+    PartitionToInputs pendingInputsOfOnePartitionRange = inputHost
+        .clearAndGetOnePartitionRange();
+    int includedMaps = 0;
+    for (Iterator<InputAttemptIdentifier> inputIter =
+         pendingInputsOfOnePartitionRange.getInputs().iterator();
+         inputIter.hasNext();) {
+      InputAttemptIdentifier input = inputIter.next();
+
+      //For pipelined shuffle.
+      if (!validateInputAttemptForPipelinedShuffle(input)) {
+        continue;
+      }
+
+      // Avoid adding attempts which have already completed.
+      boolean alreadyCompleted;
+      if (input instanceof CompositeInputAttemptIdentifier) {
+        CompositeInputAttemptIdentifier compositeInput = 
(CompositeInputAttemptIdentifier)input;
+        int nextClearBit = 
completedInputSet.nextClearBit(compositeInput.getInputIdentifier());
+        int maxClearBit = compositeInput.getInputIdentifier() + 
compositeInput.getInputIdentifierCount();
+        alreadyCompleted = nextClearBit > maxClearBit;
+      } else {
+        alreadyCompleted = completedInputSet.get(input.getInputIdentifier());
+      }
+      // Avoid adding attempts which have already completed or have been 
marked as OBSOLETE
+      if (alreadyCompleted || obsoletedInputs.contains(input)) {
+        inputIter.remove();
+        continue;
+      }
+
+      // Check if max threshold is met
+      if (includedMaps >= maxTaskOutputAtOnce) {
+        inputIter.remove();
+        //add to inputHost
+        
inputHost.addKnownInput(pendingInputsOfOnePartitionRange.getPartition(),
+            pendingInputsOfOnePartitionRange.getPartitionCount(), input);
+      } else {
+        includedMaps++;
+      }
+    }
+    if (inputHost.getNumPendingPartitions() > 0) {
+      pendingHosts.add(inputHost); //add it to queue
+    }
+    for (InputAttemptIdentifier input : 
pendingInputsOfOnePartitionRange.getInputs()) {
+      ShuffleEventInfo eventInfo = 
shuffleInfoEventsMap.get(input.getInputIdentifier());
+      if (eventInfo != null) {
+        eventInfo.scheduledForDownload = true;
+      }
+    }
+    fetcherBuilder.assignWork(inputHost.getHost(), inputHost.getPort(),
+        pendingInputsOfOnePartitionRange.getPartition(),
+        pendingInputsOfOnePartitionRange.getPartitionCount(),
+        pendingInputsOfOnePartitionRange.getInputs());
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Created Fetcher for host: " + inputHost.getHost()
+          + ", info: " + inputHost.getAdditionalInfo()
+          + ", with inputs: " + pendingInputsOfOnePartitionRange);
+    }
+    return fetcherBuilder.build();
+  }
+
+
+  /////////////////// Methods for InputEventHandler
  /**
   * Registers a newly announced source attempt (map output) with this manager.
   *
   * <p>Publishes the producing host into {@code knownSrcHosts}, creates a
   * {@code ShuffleEventInfo} per covered input, queues the host for fetching, and
   * records the RSS partitions contributed by this attempt.
   *
   * @param hostName host that produced the output
   * @param port port on the producing host
   * @param srcAttemptIdentifier composite identifier covering one or more inputs
   * @param srcPhysicalIndex first physical partition index covered by this attempt
   */
  @Override
  public void addKnownInput(String hostName, int port,
                            CompositeInputAttemptIdentifier srcAttemptIdentifier, int srcPhysicalIndex) {
    HostPort identifier = new HostPort(hostName, port);
    InputHost host = knownSrcHosts.get(identifier);
    if (host == null) {
      // First event for this host: install a new InputHost. putIfAbsent resolves the
      // race with a concurrent registration of the same host/port pair.
      host = new InputHost(identifier);
      InputHost old = knownSrcHosts.putIfAbsent(identifier, host);
      if (old != null) {
        host = old;
      }
    }
    if (LOG.isDebugEnabled()) {
      LOG.debug(srcNameTrimmed + ": " + "Adding input: " + srcAttemptIdentifier + ", to host: " + host);
    }

    // Pipelined shuffle validation may reject the attempt (e.g. conflicting attempt
    // numbers); in that case the event is dropped entirely.
    if (!validateInputAttemptForPipelinedShuffle(srcAttemptIdentifier)) {
      return;
    }
    int inputIdentifier = srcAttemptIdentifier.getInputIdentifier();
    // Create event-tracking state for each input the composite identifier expands to,
    // keyed by (base identifier + offset).
    for (int i = 0; i < srcAttemptIdentifier.getInputIdentifierCount(); i++) {
      if (shuffleInfoEventsMap.get(inputIdentifier + i) == null) {
        shuffleInfoEventsMap.put(inputIdentifier + i, new ShuffleEventInfo(srcAttemptIdentifier.expand(i)));
        LOG.info("AddKnownInput, srcAttemptIdentifier:{}, i:{}, expand:{}, map:{}",
            srcAttemptIdentifier, i, srcAttemptIdentifier.expand(i), shuffleInfoEventsMap);
      }
    }

    host.addKnownInput(srcPhysicalIndex, srcAttemptIdentifier.getInputIdentifierCount(), srcAttemptIdentifier);
    // First critical section: enqueue the host for the scheduler loop and wake it.
    lock.lock();
    try {
      boolean added = pendingHosts.offer(host);
      if (!added) {
        String errorMessage = "Unable to add host: " + host.getIdentifier() + " to pending queue";
        LOG.error(errorMessage);
        throw new TezUncheckedException(errorMessage);
      }
      wakeLoop.signal();
    } finally {
      lock.unlock();
    }

    LOG.info("AddKnowInput, hostname:{}, port:{}, srcAttemptIdentifier:{}, srcPhysicalIndex:{}",
            hostName, port, srcAttemptIdentifier, srcPhysicalIndex);

    // Second critical section: record partition membership for the RSS fetch path.
    // NOTE(review): the lock is released and re-acquired between the two sections, so
    // another thread may observe the host queued before partitions are recorded —
    // confirm this window is benign.
    lock.lock();
    try {
      for (int i = 0; i < srcAttemptIdentifier.getInputIdentifierCount(); i++) {
        int p = srcPhysicalIndex + i;
        LOG.info("PartitionToInput, original:{}, add:{},  now:{}",
                srcAttemptIdentifier, srcAttemptIdentifier.expand(i), partitionToInput.get(p));
        // (srcPhysicalIndex + i) is the same value as p; a partition seen for the
        // first time also becomes pending.
        if (!allRssPartition.contains(srcPhysicalIndex + i)) {
          pendingPartition.add(p);
        }
        allRssPartition.add(p);
        partitionToInput.putIfAbsent(p, new ArrayList<>());
        partitionToInput.get(p).add(srcAttemptIdentifier);
        LOG.info("Add partition:{}, after add, now partition:{}", p, allRssPartition);
      }

      numWithDataInput.incrementAndGet();
      LOG.info("numWithDataInput:{}.", numWithDataInput.get());
    } finally {
      lock.unlock();
    }
  }
+
  /**
   * Marks a source task that produced no data as complete, using a
   * {@link NullFetchedInput} placeholder so completion bookkeeping stays uniform
   * with the data-bearing path.
   *
   * @param srcAttemptIdentifier the empty source attempt
   */
  @Override
  public void addCompletedInputWithNoData(
      InputAttemptIdentifier srcAttemptIdentifier) {
    int inputIdentifier = srcAttemptIdentifier.getInputIdentifier();
    if (LOG.isDebugEnabled()) {
      LOG.debug("No input data exists for SrcTask: " + inputIdentifier + ". Marking as complete.");
    }
    lock.lock();
    try {
      // Only register if this input has not already been completed by another event.
      if (!completedInputSet.get(inputIdentifier)) {
        NullFetchedInput fetchedInput = new NullFetchedInput(srcAttemptIdentifier);
        if (!srcAttemptIdentifier.canRetrieveInputInChunks()) {
          registerCompletedInput(fetchedInput);
        } else {
          // Pipelined shuffle: completion is claimed only once all spills are seen.
          registerCompletedInputForPipelinedShuffle(srcAttemptIdentifier, fetchedInput);
        }
      }
      // Awake the loop to check for termination.
      wakeLoop.signal();
    } finally {
      lock.unlock();
    }
    // Counter bump is outside the lock; numNoDataInput is atomic.
    numNoDataInput.incrementAndGet();
    LOG.info("AddCompletedInputWithNoData, numNoDataInput:{}, numWithDataInput:{},numInputs:{},  "
            + "successRssPartitionSet:{}, allRssPartition:{}.",
        numNoDataInput, numWithDataInput, numInputs, successRssPartitionSet, allRssPartition);
  }
+
+  @Override
+  protected synchronized void updateEventReceivedTime() {
+    long relativeTime = System.currentTimeMillis() - startTime;
+    if (firstEventReceived.getValue() == 0) {
+      firstEventReceived.setValue(relativeTime);
+      lastEventReceived.setValue(relativeTime);
+      return;
+    }
+    lastEventReceived.setValue(relativeTime);
+  }
+
  /**
   * Marks a source attempt obsolete. Obsoleted attempts are filtered out when
   * pending-input lists are assembled (see the {@code obsoletedInputs.contains}
   * check in the fetch-list construction above); running fetchers are not notified.
   */
  @Override
  void obsoleteKnownInput(InputAttemptIdentifier srcAttemptIdentifier) {
    obsoletedInputs.add(srcAttemptIdentifier);
    // NEWTEZ Maybe inform the fetcher about this. For now, this is used during the initial fetch list construction.
  }
+
+  // End of Methods for InputEventHandler
+  // Methods from FetcherCallbackHandler
+
  /**
   * Placeholder for tracking shuffle events in case we get multiple spills info for the same
   * attempt.
   */
  static class ShuffleEventInfo {
    BitSet eventsProcessed;       // one bit per spill event already processed
    int finalEventId = -1;        // 0 indexed; -1 until the FINAL_UPDATE spill is seen
    int attemptNum;               // attempt number of the source task
    String id;                    // "<inputIdentifier>_<attemptNumber>", for diagnostics
    boolean scheduledForDownload; // whether chunks got scheduled for download


    ShuffleEventInfo(InputAttemptIdentifier input) {
      this.id = input.getInputIdentifier() + "_" + input.getAttemptNumber();
      this.eventsProcessed = new BitSet();
      this.attemptNum = input.getAttemptNumber();
    }

    /**
     * Records that the given spill was processed. Once the final event id is known,
     * verifies we have not processed more spills than the final id allows.
     */
    void spillProcessed(int spillId) {
      if (finalEventId != -1) {
        Preconditions.checkState(eventsProcessed.cardinality() <= (finalEventId + 1),
            "Wrong state. eventsProcessed cardinality=" + eventsProcessed.cardinality() + " "
                + "finalEventId=" + finalEventId + ", spillId=" + spillId + ", " + toString());
      }
      eventsProcessed.set(spillId);
    }

    /** Records the id of the last spill (carried by the FINAL_UPDATE event). */
    void setFinalEventId(int spillId) {
      finalEventId = spillId;
    }

    /**
     * @return true once the final event id is known and every spill up to it has
     *     been processed.
     */
    boolean isDone() {
      if (LOG.isDebugEnabled()) {
        LOG.debug("finalEventId=" + finalEventId + ", eventsProcessed cardinality="
            + eventsProcessed.cardinality());
      }
      return ((finalEventId != -1) && (finalEventId + 1) == eventsProcessed.cardinality());
    }

    @Override
    public String toString() {
      return "[eventsProcessed=" + eventsProcessed + ", finalEventId=" + finalEventId
          +  ", id=" + id + ", attemptNum=" + attemptNum
          + ", scheduledForDownload=" + scheduledForDownload + "]";
    }
  }
+
  /**
   * Callback invoked by a fetcher when one input has been fetched successfully.
   * Commits the fetched data, updates shuffle counters, and registers completion
   * (directly, or via the pipelined-shuffle path for chunked inputs).
   *
   * @param host host the data was fetched from
   * @param srcAttemptIdentifier identifier of the fetched source attempt
   * @param fetchedInput the committed-to-be input holding the data
   * @param fetchedBytes compressed bytes transferred
   * @param decompressedLength size after decompression
   * @param copyDuration time taken by the copy, for stats logging
   * @throws IOException if committing the fetched input fails
   */
  @Override
  public void fetchSucceeded(String host, InputAttemptIdentifier srcAttemptIdentifier,
                             FetchedInput fetchedInput, long fetchedBytes, long decompressedLength, long copyDuration)
      throws IOException {
    // Count irrespective of whether this is a copy of an already fetched input
    lock.lock();
    try {
      lastProgressTime = System.currentTimeMillis();
      inputContext.notifyProgress();
      fetchedInput.commit();
      fetchStatsLogger.logIndividualFetchComplete(copyDuration,
          fetchedBytes, decompressedLength, fetchedInput.getType().toString(), srcAttemptIdentifier);

      // Processing counters for completed and commit fetches only. Need
      // additional counters for excessive fetches - which primarily comes
      // in after speculation or retries.
      shuffledInputsCounter.increment(1);
      bytesShuffledCounter.increment(fetchedBytes);
      // Attribute bytes to the right destination counter based on where the
      // fetched input landed.
      if (fetchedInput.getType() == FetchedInput.Type.MEMORY) {
        bytesShuffledToMemCounter.increment(fetchedBytes);
      } else if (fetchedInput.getType() == FetchedInput.Type.DISK) {
        LOG.warn("Rss bytesShuffledToDiskCounter");
        bytesShuffledToDiskCounter.increment(fetchedBytes);
      } else if (fetchedInput.getType() == FetchedInput.Type.DISK_DIRECT) {
        LOG.warn("Rss bytesShuffledDirectDiskCounter");
        bytesShuffledDirectDiskCounter.increment(fetchedBytes);
      }
      decompressedDataSizeCounter.increment(decompressedLength);

      if (!srcAttemptIdentifier.canRetrieveInputInChunks()) {
        registerCompletedInput(fetchedInput);
      } else {
        LOG.warn("Rss registerCompletedInputForPipelinedShuffle");
        registerCompletedInputForPipelinedShuffle(srcAttemptIdentifier, fetchedInput);
      }

      totalBytesShuffledTillNow += fetchedBytes;
      logProgress();
      // Wake the scheduler loop so it can re-evaluate completion.
      wakeLoop.signal();

    } finally {
      lock.unlock();
    }
    // NEWTEZ Maybe inform fetchers, in case they have an alternate attempt of the same task in their queue.
  }
+
+
  /**
   * Marks a non-pipelined input as fully fetched: publishes it to consumers,
   * updates completion bookkeeping, and counts the fetched spill.
   */
  private void registerCompletedInput(FetchedInput fetchedInput) {
    lock.lock();
    try {
      maybeInformInputReady(fetchedInput);
      adjustCompletedInputs(fetchedInput);
      numFetchedSpills.getAndIncrement();
    } finally {
      lock.unlock();
    }
  }
+
  /**
   * Queues a real (non-placeholder) fetched input for downstream consumption and,
   * exactly once per manager, tells the InputContext that input is ready.
   */
  private void maybeInformInputReady(FetchedInput fetchedInput) {
    lock.lock();
    try {
      if (!(fetchedInput instanceof NullFetchedInput)) {
        // NullFetchedInput is an empty-source placeholder and is never handed to consumers.
        LOG.info("maybeInformInputReady");
        completedInputs.add(fetchedInput);
      }
      // getAndSet guarantees the notification fires only once.
      if (!inputReadyNotificationSent.getAndSet(true)) {
        // Should eventually be controlled by Inputs which are processing the data.
        LOG.info("maybeInformInputReady InputContext inputIsReady");
        inputContext.inputIsReady();
      }
    } finally {
      lock.unlock();
    }
  }
+
  /**
   * Records final completion for the input attempt backing {@code fetchedInput}:
   * sets its bit in {@code completedInputSet} and bumps the completed-input count.
   */
  private void adjustCompletedInputs(FetchedInput fetchedInput) {
    lock.lock();
    try {
      completedInputSet.set(fetchedInput.getInputAttemptIdentifier().getInputIdentifier());
      int numComplete = numCompletedInputs.incrementAndGet();
      LOG.info("AdjustCompletedInputs, numCompletedInputs:{}", numComplete);
    } finally {
      lock.unlock();
    }
  }
+
  /**
   * Completion bookkeeping for pipelined shuffle, where one input attempt can
   * deliver multiple spills. Each spill is registered as it arrives; final
   * completion is claimed only once all spills (up to the FINAL_UPDATE event)
   * have been processed.
   *
   * @param srcAttemptIdentifier the attempt whose spill just finished
   * @param fetchedInput the fetched data (or a NullFetchedInput for empty sources)
   */
  private void registerCompletedInputForPipelinedShuffle(InputAttemptIdentifier srcAttemptIdentifier,
            FetchedInput fetchedInput) {
    /**
     * For pipelinedshuffle it is possible to get multiple spills. Claim success only when
     * all spills pertaining to an attempt are done.
     */
    if (!validateInputAttemptForPipelinedShuffle(srcAttemptIdentifier)) {
      return;
    }

    int inputIdentifier = srcAttemptIdentifier.getInputIdentifier();
    ShuffleEventInfo eventInfo = shuffleInfoEventsMap.get(inputIdentifier);

    //for empty partition case
    if (eventInfo == null && fetchedInput instanceof NullFetchedInput) {
      eventInfo = new ShuffleEventInfo(srcAttemptIdentifier);
      shuffleInfoEventsMap.put(inputIdentifier, eventInfo);
    }

    // NOTE(review): with assertions disabled, a missing eventInfo for a
    // non-null fetchedInput would surface as an NPE on the next line —
    // confirm whether an explicit state check is wanted here.
    assert (eventInfo != null);
    eventInfo.spillProcessed(srcAttemptIdentifier.getSpillEventId());
    numFetchedSpills.getAndIncrement();

    // FINAL_UPDATE carries the id of the last spill for this attempt.
    if (srcAttemptIdentifier.getFetchTypeInfo() == InputAttemptIdentifier.SPILL_INFO.FINAL_UPDATE) {
      eventInfo.setFinalEventId(srcAttemptIdentifier.getSpillEventId());
    }

    lock.lock();
    try {
      /**
       * When fetch is complete for a spill, add it to completedInputs to ensure that it is
       * available for downstream processing. Final success will be claimed only when all
       * spills are downloaded from the source.
       */
      maybeInformInputReady(fetchedInput);

      //check if we downloaded all spills pertaining to this InputAttemptIdentifier
      if (eventInfo.isDone()) {
        adjustCompletedInputs(fetchedInput);
        shuffleInfoEventsMap.remove(srcAttemptIdentifier.getInputIdentifier());
      }
    } finally {
      lock.unlock();
    }

    if (LOG.isTraceEnabled()) {
      LOG.trace("eventInfo " + eventInfo.toString());
    }
  }
+
+  private void reportFatalError(Throwable exception, String message) {
+    LOG.error(message);
+    inputContext.reportFailure(TaskFailureType.NON_FATAL, exception, message);
+  }
+
+  @Override
+  public void fetchFailed(String host, InputAttemptIdentifier 
srcAttemptIdentifier, boolean connectFailed) {
+    // NEWTEZ. Implement logic to report fetch failures after a threshold.
+    // For now, reporting immediately.
+    LOG.info(srcNameTrimmed + ": " + "Fetch failed for src: " + 
srcAttemptIdentifier
+        + "InputIdentifier: " + srcAttemptIdentifier + ", connectFailed: "
+        + connectFailed);
+    failedShufflesCounter.increment(1);
+    inputContext.notifyProgress();
+    if (srcAttemptIdentifier == null) {
+      reportFatalError(null, "Received fetchFailure for an unknown src 
(null)");
+    } else {
+      InputReadErrorEvent readError = InputReadErrorEvent.create(
+          "Fetch failure while fetching from "
+              + TezRuntimeUtils.getTaskAttemptIdentifier(
+              inputContext.getSourceVertexName(),
+              srcAttemptIdentifier.getInputIdentifier(),
+              srcAttemptIdentifier.getAttemptNumber()),
+          srcAttemptIdentifier.getInputIdentifier(),
+          srcAttemptIdentifier.getAttemptNumber());
+      if (maxTimeToWaitForReportMillis > 0) {
+        try {
+          reportLock.lock();
+          failedEvents.merge(readError, 1, (a, b) -> a + b);
+          reportCondition.signal();
+        } finally {
+          reportLock.unlock();
+        }
+      } else {
+        List<Event> events = Lists.newArrayListWithCapacity(1);
+        events.add(readError);
+        inputContext.sendEvents(events);
+      }
+    }
+  }
+  // End of Methods from FetcherCallbackHandler
+
  /**
   * Shuts the manager down at most once: stops all running RSS fetchers, then
   * shuts down the scheduler, reporter, and fetcher executors. Safe to call
   * multiple times; only the first call does the work.
   *
   * @throws InterruptedException declared for interface compatibility
   */
  @Override
  public void shutdown() throws InterruptedException {
    if (Thread.currentThread().isInterrupted()) {
      // need to cleanup all FetchedInput (DiskFetchedInput, LocalDisFetchedInput), lockFile
      // As of now relying on job cleanup (when all directories would be cleared)
      LOG.info(srcNameTrimmed + ": " + "Thread interrupted. Need to cleanup the local dirs");
    }
    // getAndSet ensures the shutdown sequence runs exactly once.
    if (!isShutdown.getAndSet(true)) {
      // Shut down any pending fetchers
      LOG.info("Shutting down pending fetchers on source" + srcNameTrimmed + ": "
          + rssRunningFetchers.size());
      lock.lock();
      try {
        wakeLoop.signal(); // signal the fetch-scheduler
        for (RssTezFetcherTask fetcher : rssRunningFetchers) {
          try {
            fetcher.shutdown(); // This could be parallelized.
          } catch (Exception e) {
            LOG.warn(
                "Error while stopping fetcher during shutdown. Ignoring and continuing. Message={}",
                e.getMessage());
          }
        }
      } finally {
        lock.unlock();
      }

      // Stop the background executors; shutdownNow interrupts in-flight work.
      if (this.schedulerExecutor != null && !this.schedulerExecutor.isShutdown()) {
        this.schedulerExecutor.shutdownNow();
      }
      if (this.reporterExecutor != null
          && !this.reporterExecutor.isShutdown()) {
        this.reporterExecutor.shutdownNow();
      }
      if (this.fetcherExecutor != null && !this.fetcherExecutor.isShutdown()) {
        this.fetcherExecutor.shutdownNow(); // Interrupts all running fetchers.
      }
    }
  }
+
+  /**
+   * @return true if all of the required inputs have been fetched.
+   */
+  public boolean isAllPartitionFetched() {
+    lock.lock();
+    try {
+      if (!allRssPartition.containsAll(successRssPartitionSet)) {
+        LOG.error("Failed to check partition, all partition:{}, success 
partiton:{}",
+            allRssPartition, successRssPartitionSet);
+      }
+      return isAllInputFetched();
+    } finally {
+      lock.unlock();
+    }
+  }
+
  /**
   * @return the next available input, or null if there are no available inputs.
   *         This method will block if there are currently no available inputs,
   *         but more may become available.
   */
  @Override
  public FetchedInput getNextInput() throws InterruptedException {
    // Check for no additional inputs
    FetchedInput fetchedInput = null;
    if (completedInputs.peek() == null) {
      // Nothing ready yet: poll until an input arrives or all partitions are fetched.
      // NOTE(review): poll timeouts use TimeUnit.MICROSECONDS (2000us / 100us) —
      // presumably MILLISECONDS was intended; confirm the unit.
      while (true) {
        fetchedInput = completedInputs.poll(2000, TimeUnit.MICROSECONDS);
        if (fetchedInput != null) {
          break;
        } else if (isAllPartitionFetched()) {
          // Everything fetched: one last poll to drain a late arrival, then exit.
          fetchedInput = completedInputs.poll(100, TimeUnit.MICROSECONDS);
          LOG.info("GetNextInput, enter isAllPartitionFetched");
          break;
        }
        LOG.info("GetNextInput, out loop");
      }
    } else {
      // Queue is non-empty; take() returns immediately here.
      fetchedInput = completedInputs.take();
    }

    // NullFetchedInput is the empty-source placeholder; callers see null for it.
    if (fetchedInput instanceof NullFetchedInput) {
      LOG.info("getNextInput, NullFetchedInput is null:{}", fetchedInput);
      fetchedInput = null;
    }
    LOG.info("getNextInput, fetchedInput:{}", fetchedInput);
    return fetchedInput;
  }
+
  /** @return the total number of inputs this manager expects. */
  @Override
  public int getNumInputs() {
    return numInputs;
  }
+
  /** @return the number of completed inputs as a float, for progress reporting. */
  @Override
  public float getNumCompletedInputsFloat() {
    return numCompletedInputs.floatValue();
  }
+
+  // End of methods for walking the available inputs
+
+
  /**
   * Fake input that is added to the completed input list in case an input does not have any data.
   *
   * <p>Only {@link #getType()} and {@link #getSize()} are usable; every data or
   * lifecycle operation throws, since there is no backing data.
   */
  @VisibleForTesting
  static class NullFetchedInput extends FetchedInput {
    NullFetchedInput(InputAttemptIdentifier inputAttemptIdentifier) {
      // No allocator is needed since there is no data to manage.
      super(inputAttemptIdentifier, null);
    }

    @Override
    public Type getType() {
      return Type.MEMORY;
    }

    @Override
    public long getSize() {
      // -1 signals "no data".
      return -1;
    }

    @Override
    public OutputStream getOutputStream() throws IOException {
      throw new UnsupportedOperationException("Not supported for NullFetchedInput");
    }

    @Override
    public InputStream getInputStream() throws IOException {
      throw new UnsupportedOperationException("Not supported for NullFetchedInput");
    }

    @Override
    public void commit() throws IOException {
      throw new UnsupportedOperationException("Not supported for NullFetchedInput");
    }

    @Override
    public void abort() throws IOException {
      throw new UnsupportedOperationException("Not supported for NullFetchedInput");
    }

    @Override
    public void free() {
      throw new UnsupportedOperationException("Not supported for NullFetchedInput");
    }
  }
+
+  private final AtomicInteger nextProgressLineEventCount = new 
AtomicInteger(0);
+
+  private void logProgress() {
+    int inputsDone = numCompletedInputs.get();
+    if (inputsDone > nextProgressLineEventCount.get() || inputsDone == 
numInputs) {
+      nextProgressLineEventCount.addAndGet(50);
+      double mbs = (double) totalBytesShuffledTillNow / (1024 * 1024);
+      long secsSinceStart = (System.currentTimeMillis() - startTime) / 1000 + 
1;
+
+      double transferRate = mbs / secsSinceStart;
+      LOG.info("copy(" + inputsDone + " (spillsFetched=" + 
numFetchedSpills.get() + ") of "
+          + numInputs
+          + ". Transfer rate (CumulativeDataFetched/TimeSinceInputStarted)) "
+          + mbpsFormat.format(transferRate) + " MB/s)");
+    }
+  }
+
+
  /**
   * Callback attached to the scheduler-loop future: logs normal completion, and
   * reports a non-fatal failure unless the manager is already shutting down.
   */
  private class SchedulerFutureCallback implements FutureCallback<Void> {
    @Override
    public void onSuccess(Void result) {
      LOG.info(srcNameTrimmed + ": " + "Scheduler thread completed");
    }

    @Override
    public void onFailure(Throwable t) {
      if (isShutdown.get()) {
        // Failures during shutdown are expected (interrupts) and not reported.
        if (LOG.isDebugEnabled()) {
          LOG.debug(srcNameTrimmed + ": " + "Already shutdown. Ignoring error: " + t);
        }
      } else {
        LOG.error(srcNameTrimmed + ": " + "Scheduler failed with error: ", t);
        inputContext.reportFailure(TaskFailureType.NON_FATAL, t, "Shuffle Scheduler Failed");
      }
    }
  }
+
  /**
   * Callback attached to each RssTezFetcherTask future. On success, records the
   * partition as fetched and removes it from the running set; on failure, reports
   * a non-fatal error. Either way the fetcher is shut down and the scheduler loop
   * is woken so it can reassess state.
   */
  private class FetchFutureCallback implements FutureCallback<FetchResult> {

    private final RssTezFetcherTask fetcher;

    FetchFutureCallback(RssTezFetcherTask fetcher) {
      this.fetcher = fetcher;
    }

    /** Removes the fetcher from the running set and wakes the scheduler loop. */
    private void doBookKeepingForFetcherComplete() {
      lock.lock();
      try {
        rssRunningFetchers.remove(fetcher);
        wakeLoop.signal();
      } finally {
        lock.unlock();
      }
    }

    @Override
    public void onSuccess(FetchResult result) {
      LOG.info("FetchFutureCallback success, result:{}, partition:{}", result, fetcher.getPartitionId());
      fetcher.shutdown();
      if (isShutdown.get()) {
        if (LOG.isDebugEnabled()) {
          LOG.debug(srcNameTrimmed + ": " + "Already shutdown. Ignoring event from fetcher");
        }
      } else {
        lock.lock();
        try {
          // Move the partition from "running" to "success" under the lock.
          successRssPartitionSet.add(fetcher.getPartitionId());
          runningRssPartitionMap.remove(fetcher.getPartitionId());
          LOG.info("FetchFutureCallback allRssPartition:{}, successRssPartitionSet:{}, runningRssPartitionMap:{}.",
              allRssPartition, successRssPartitionSet, runningRssPartitionMap);
          doBookKeepingForFetcherComplete();
        } finally {
          lock.unlock();
        }
      }
    }


    @Override
    public void onFailure(Throwable t) {
      // Unsuccessful - the fetcher may not have shutdown correctly. Try shutting it down.
      fetcher.shutdown();
      if (isShutdown.get()) {
        if (LOG.isDebugEnabled()) {
          LOG.debug(srcNameTrimmed + ": " + "Already shutdown. Ignoring error from fetcher: " + t);
        }
      } else {
        LOG.error(srcNameTrimmed + ": " + "Fetcher failed with error: ", t);
        shuffleError = t;
        inputContext.reportFailure(TaskFailureType.NON_FATAL, t, "Fetch failed");
        doBookKeepingForFetcherComplete();
      }
    }
  }
+}
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
new file mode 100644
index 00000000..35cedf72
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.tez.common.CallableWithNdc;
+import org.apache.tez.common.IdUtils;
+import org.apache.tez.common.InputContextUtils;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.FetchResult;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator;
+import org.apache.tez.runtime.library.common.shuffle.FetcherCallback;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.util.UnitConverter;
+
/**
 * A callable fetch task that retrieves all shuffle blocks of one partition from
 * the RSS (remote shuffle service) servers: it first obtains the block-id bitmap
 * from the servers, derives the expected task-id bitmap from the known inputs,
 * then streams the blocks through a {@code RssTezFetcher}, reporting results via
 * the supplied {@link FetcherCallback}.
 */
public class RssTezFetcherTask extends CallableWithNdc<FetchResult> {
  private static final Logger LOG = LoggerFactory.getLogger(RssTezFetcherTask.class);

  private final FetcherCallback fetcherCallback;

  private final InputContext inputContext;
  private final Configuration conf;
  private final FetchedInputAllocator inputManager;
  private final int partition;

  // Source attempts contributing to this partition.
  List<InputAttemptIdentifier> inputs;
  private Set<ShuffleServerInfo> serverInfoSet;
  // Shared maps (keyed by partition) populated by this task for the manager.
  Map<Integer, Roaring64NavigableMap> rssAllBlockIdBitmapMap;
  Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap;
  private String clientType = null;
  private final int numPhysicalInputs;
  private final String appId;
  private final int dagIdentifier;
  private final int vertexIndex;
  private final int reduceId;

  private String storageType;
  // NOTE(review): basePath is never assigned anywhere in this class, so null is
  // passed to CreateShuffleReadClientRequest below — confirm this is intended
  // for the configured storage type.
  private String basePath;
  private final int readBufferSize;
  private final int partitionNumPerRange;
  private final int partitionNum;


  /**
   * Builds a fetch task for one partition.
   *
   * @param fetcherCallback receives per-fetch success/failure notifications
   * @param inputContext Tez input context of the consuming task
   * @param conf job configuration carrying the RSS client settings
   * @param inputManager allocator for fetched-input buffers
   * @param partition partition id to fetch
   * @param inputs source attempts contributing to the partition; must be non-empty
   * @param serverInfoList RSS servers holding the partition's data
   * @param rssAllBlockIdBitmapMap shared map to publish the partition's block ids into
   * @param rssSuccessBlockIdBitmapMap shared map tracking successfully read block ids
   * @param numPhysicalInputs number of physical inputs feeding this consumer
   * @param partitionNum total number of partitions of the shuffle
   */
  public RssTezFetcherTask(FetcherCallback fetcherCallback, InputContext inputContext, Configuration conf,
            FetchedInputAllocator inputManager, int partition,
            List<InputAttemptIdentifier> inputs, Set<ShuffleServerInfo> serverInfoList,
            Map<Integer, Roaring64NavigableMap> rssAllBlockIdBitmapMap,
            Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap,
            int numPhysicalInputs, int partitionNum) {
    // NOTE(review): bare assert — no validation when assertions are disabled.
    assert (inputs != null && inputs.size() > 0);
    this.fetcherCallback = fetcherCallback;
    this.inputContext = inputContext;
    this.conf = conf;
    this.inputManager = inputManager;
    this.partition = partition;  // partition id to fetch
    this.inputs = inputs;

    this.serverInfoSet = serverInfoList;
    this.rssAllBlockIdBitmapMap = rssAllBlockIdBitmapMap;
    this.rssSuccessBlockIdBitmapMap = rssSuccessBlockIdBitmapMap;
    this.numPhysicalInputs = numPhysicalInputs;
    this.partitionNum = partitionNum;

    this.appId = IdUtils.getApplicationAttemptId().toString();
    this.dagIdentifier = this.inputContext.getDagIdentifier();
    this.vertexIndex = this.inputContext.getTaskVertexIndex();

    this.reduceId =  this.inputContext.getTaskIndex();
    LOG.info("RssTezFetcherTask, dagIdentifier:{}, vertexIndex:{}, reduceId:{}.", dagIdentifier, vertexIndex, reduceId);
    clientType = conf.get(RssTezConfig.RSS_CLIENT_TYPE, RssTezConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
    this.storageType = conf.get(RssTezConfig.RSS_STORAGE_TYPE, RssTezConfig.RSS_STORAGE_TYPE_DEFAULT_VALUE);
    LOG.info("RssTezFetcherTask storageType:{}", storageType);

    // Read-buffer size is configured as a byte string (e.g. "14m") and converted to bytes.
    String readBufferSize = conf.get(RssTezConfig.RSS_CLIENT_READ_BUFFER_SIZE,
        RssTezConfig.RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE);
    this.readBufferSize = (int) UnitConverter.byteStringAsBytes(readBufferSize);
    this.partitionNumPerRange = conf.getInt(RssTezConfig.RSS_PARTITION_NUM_PER_RANGE,
        RssTezConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE);
    LOG.info("RssTezFetcherTask fetch partition:{}, with inputs:{}, readBufferSize:{}, partitionNumPerRange:{}.",
        this.partition, inputs, this.readBufferSize, this.partitionNumPerRange);
  }

  /**
   * Performs the fetch: queries block ids from the RSS servers, computes the
   * expected task-id bitmap, and (if any tasks produced data) reads all blocks
   * of the partition via a RssTezFetcher.
   *
   * @return always null; results are delivered through the FetcherCallback
   * @throws Exception on any client or read failure
   */
  @Override
  protected FetchResult callInternal() throws Exception {
    // get assigned RSS servers
    // just get blockIds from RSS servers
    int shuffleId = InputContextUtils.computeShuffleId(inputContext);

    // NOTE(review): writeClient.close() is not in a finally block, so a failing
    // getShuffleResult leaks the client — confirm whether that is acceptable.
    ShuffleWriteClient writeClient = RssTezUtils.createShuffleClient(this.conf);
    LOG.info("WriteClient getShuffleResult, clientType:{}, serverInfoSet:{}, appId:{}, shuffleId:{}, partition:{}",
        clientType, serverInfoSet, appId, shuffleId, partition);
    Roaring64NavigableMap blockIdBitmap = writeClient.getShuffleResult(
        clientType, serverInfoSet, appId, shuffleId, partition);
    writeClient.close();
    rssAllBlockIdBitmapMap.put(partition, blockIdBitmap);

    // get map-completion events to generate RSS taskIDs
    // final RssEventFetcher eventFetcher = new RssEventFetcher(inputs, numPhysicalInputs);
    int appAttemptId = IdUtils.getAppAttemptId();
    Roaring64NavigableMap taskIdBitmap = RssTezUtils.fetchAllRssTaskIds(
        new HashSet<>(inputs), numPhysicalInputs,
        appAttemptId);
    LOG.info("inputs:{}, num input:{}, appAttemptId:{}, taskIdBitmap:{}",
        inputs, numPhysicalInputs, appAttemptId, taskIdBitmap);

    LOG.info("In reduce: " + reduceId
        + ", RSS Tez client has fetched blockIds and taskIds successfully");
    // start fetcher to fetch blocks from RSS servers
    if (!taskIdBitmap.isEmpty()) {
      LOG.info("In reduce: " + reduceId + ", Rss Tez client starts to fetch blocks from RSS server");
      JobConf readerJobConf = getRemoteConf();
      LOG.info("RssTezFetcherTask storageType:{}", storageType);
      // With multiple servers a block may be duplicated; filtering by expected
      // task ids is only needed then.
      boolean expectedTaskIdsBitmapFilterEnable = serverInfoSet.size() > 1;
      CreateShuffleReadClientRequest request = new CreateShuffleReadClientRequest(
          appId,
          InputContextUtils.computeShuffleId(inputContext),
          partition,
          basePath,
          partitionNumPerRange,
          partitionNum,
          blockIdBitmap,
          taskIdBitmap,
          new ArrayList<>(serverInfoSet),
          readerJobConf,
          expectedTaskIdsBitmapFilterEnable, RssTezConfig.toRssConf(this.conf));
      ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
      RssTezFetcher fetcher = new RssTezFetcher(fetcherCallback,
          inputManager,
          shuffleReadClient,
          rssSuccessBlockIdBitmapMap,
          partition, RssTezConfig.toRssConf(this.conf));
      fetcher.fetchAllRssBlocks();
      LOG.info("In reduce: " + partition
          + ", Rss Tez client fetches blocks from RSS server successfully");
    }
    return null;
  }

  /** No-op: this task holds no long-lived resources to release. */
  public void shutdown() {
  }

  /** @return a JobConf copy of the task configuration for the read client. */
  private JobConf getRemoteConf() {
    return new JobConf(conf);
  }

  /** @return the partition id this task fetches. */
  public int getPartitionId() {
    return partition;
  }

  /**
   * Equality is based on the identifying coordinates (partition, input count,
   * DAG/vertex/reduce ids, appId) — not on the mutable fetch state.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    RssTezFetcherTask that = (RssTezFetcherTask) o;
    return partition == that.partition
        && numPhysicalInputs == that.numPhysicalInputs
        && dagIdentifier == that.dagIdentifier
        && vertexIndex == that.vertexIndex
        && reduceId == that.reduceId
        && Objects.equals(appId, that.appId);
  }

  @Override
  public int hashCode() {
    // Consistent with equals: same identifying fields.
    return Objects.hash(partition, numPhysicalInputs, dagIdentifier, vertexIndex, reduceId, appId);
  }
}
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManagerTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManagerTest.java
new file mode 100644
index 00000000..8d49cf15
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManagerTest.java
@@ -0,0 +1,341 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.security.token.Token;
+import org.apache.tez.common.TezExecutors;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.common.TezSharedExecutor;
+import org.apache.tez.common.UmbilicalUtils;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.http.HttpConnectionParams;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.ExecutionContext;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.api.events.DataMovementEvent;
+import org.apache.tez.runtime.api.events.InputReadErrorEvent;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.FetchResult;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator;
+import org.apache.tez.runtime.library.common.shuffle.Fetcher;
+import org.apache.tez.runtime.library.common.shuffle.InputHost;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
+import 
org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.DataMovementEventPayloadProto;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class RssShuffleManagerTest {
+  private static final String FETCHER_HOST = "localhost";
+  private static final int PORT = 8080;
+  private static final String PATH_COMPONENT = "attempttmp";
+  private static final Configuration conf = new Configuration();
+  private static TezExecutors sharedExecutor;
+
+  @BeforeAll
+  public static void setup() {
+    sharedExecutor = new TezSharedExecutor(conf);
+  }
+
+  @AfterAll
+  public static void cleanup() {
+    sharedExecutor.shutdownNow();
+  }
+
+  private InputContext createInputContext() throws IOException {
+    DataOutputBuffer portDob = new DataOutputBuffer();
+    portDob.writeInt(PORT);
+    final ByteBuffer shuffleMetaData = ByteBuffer.wrap(portDob.getData(), 0, 
portDob.getLength());
+    portDob.close();
+
+    ExecutionContext executionContext = mock(ExecutionContext.class);
+    doReturn(FETCHER_HOST).when(executionContext).getHostName();
+
+    InputContext inputContext = mock(InputContext.class);
+    doReturn(new TezCounters()).when(inputContext).getCounters();
+    doReturn(shuffleMetaData).when(inputContext)
+        
.getServiceProviderMetaData(conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
+            TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT));
+    doReturn(executionContext).when(inputContext).getExecutionContext();
+    doReturn("Map 1").when(inputContext).getSourceVertexName();
+    doReturn("Reducer 1").when(inputContext).getTaskVertexName();
+    
when(inputContext.getUniqueIdentifier()).thenReturn("attempt_1685094627632_0157_1_01_000000_0_10006");
+    return inputContext;
+  }
+
+  @Test
+  @Timeout(value = 50000, unit = TimeUnit.MILLISECONDS)
+  public void testUseSharedExecutor() throws Exception {
+    try (MockedStatic<ShuffleUtils> shuffleUtils = 
Mockito.mockStatic(ShuffleUtils.class)) {
+      shuffleUtils.when(() -> 
ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);
+      shuffleUtils.when(() -> 
ShuffleUtils.getHttpConnectionParams(any())).thenReturn(
+          new HttpConnectionParams(false, 1000, 5000, 1000, 104857600, false, 
null));
+
+      try (MockedStatic<UmbilicalUtils> umbilicalUtils = 
Mockito.mockStatic(UmbilicalUtils.class)) {
+        Map<Integer, List<ShuffleServerInfo>> workers = new HashMap<>();
+        workers.put(1, ImmutableList.of(new ShuffleServerInfo("127.0.0.1", 
2181)));
+        umbilicalUtils.when(() -> UmbilicalUtils.requestShuffleServer(any(), 
any(), any(), anyInt()))
+            .thenReturn(workers);
+
+        InputContext inputContext = createInputContext();
+        createShuffleManager(inputContext, 2);
+        verify(inputContext, 
times(0)).createTezFrameworkExecutorService(anyInt(), anyString());
+      }
+    }
+  }
+
+  @Test
+  @Timeout(value = 20000, unit = TimeUnit.MILLISECONDS)
+  public void testProgressWithEmptyPendingHosts() throws Exception {
+    try (MockedStatic<ShuffleUtils> shuffleUtils = 
Mockito.mockStatic(ShuffleUtils.class)) {
+      shuffleUtils.when(() -> 
ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);
+      HttpConnectionParams params = new HttpConnectionParams(false, 1000, 5000,
+          1000, 1024 * 1024 * 100, false, null);
+      shuffleUtils.when(() -> 
ShuffleUtils.getHttpConnectionParams(any())).thenReturn(params);
+      InputContext inputContext = createInputContext();
+      final ShuffleManager shuffleManager = 
spy(createShuffleManager(inputContext, 1));
+      Thread schedulerGetHostThread = new Thread(new Runnable() {
+        @Override
+        public void run() {
+          try (MockedStatic<UmbilicalUtils> umbilicalUtils = 
Mockito.mockStatic(UmbilicalUtils.class)) {
+            Map<Integer, List<ShuffleServerInfo>> workers = new HashMap<>();
+            workers.put(1, ImmutableList.of(new ShuffleServerInfo("127.0.0.1", 
2181)));
+            umbilicalUtils.when(() -> 
UmbilicalUtils.requestShuffleServer(any(), any(), any(), anyInt()))
+                .thenReturn(workers);
+            try {
+              shuffleManager.run();
+            } catch (Exception e) {
+              e.printStackTrace();
+            }
+          }
+        }
+      });
+      schedulerGetHostThread.start();
+      Thread.currentThread().sleep(1000 * 3 + 1000);
+      schedulerGetHostThread.interrupt();
+      verify(inputContext, atLeast(3)).notifyProgress();
+    }
+  }
+
+  @Test
+  @Timeout(value = 2000000, unit = TimeUnit.MILLISECONDS)
+  public void testFetchFailed() throws Exception {
+    try (MockedStatic<ShuffleUtils> shuffleUtils = 
Mockito.mockStatic(ShuffleUtils.class)) {
+      shuffleUtils.when(() -> 
ShuffleUtils.deserializeShuffleProviderMetaData(any())).thenReturn(4);
+      shuffleUtils.when(() -> 
ShuffleUtils.getHttpConnectionParams(any())).thenReturn(
+          new HttpConnectionParams(false, 1000, 5000, 1000, 1024 * 1024 * 100, 
false, null));
+
+      InputContext inputContext = createInputContext();
+      final ShuffleManager shuffleManager = 
spy(createShuffleManager(inputContext, 1));
+      Thread schedulerGetHostThread = new Thread(new Runnable() {
+        @Override
+        public void run() {
+          try (MockedStatic<UmbilicalUtils> umbilicalUtils = 
Mockito.mockStatic(UmbilicalUtils.class)) {
+            Map<Integer, List<ShuffleServerInfo>> workers = new HashMap<>();
+            workers.put(1, ImmutableList.of(new ShuffleServerInfo("127.0.0.1", 
2181)));
+            umbilicalUtils.when(() -> 
UmbilicalUtils.requestShuffleServer(any(), any(), any(), anyInt()))
+                .thenReturn(workers);
+            try {
+              shuffleManager.run();
+            } catch (Exception e) {
+              e.printStackTrace();
+            }
+          }
+        }
+      });
+      InputAttemptIdentifier inputAttemptIdentifier  = new 
InputAttemptIdentifier(1, 1);
+
+      schedulerGetHostThread.start();
+      Thread.sleep(1000);
+      shuffleManager.fetchFailed("host1", inputAttemptIdentifier, false);
+      Thread.sleep(1000);
+
+      ArgumentCaptor<List> captor = ArgumentCaptor.forClass(List.class);
+      verify(inputContext, times(1))
+          .sendEvents(captor.capture());
+      assertEquals(captor.getAllValues().size(), 1);
+      List<Event> capturedList = captor.getAllValues().get(0);
+      assertEquals(capturedList.size(), 1);
+      InputReadErrorEvent inputEvent = (InputReadErrorEvent) 
capturedList.get(0);
+
+      shuffleManager.fetchFailed("host1", inputAttemptIdentifier, false);
+      shuffleManager.fetchFailed("host1", inputAttemptIdentifier, false);
+
+      // Wait more than five seconds for the batch to go out
+      Thread.sleep(5000);
+      captor = ArgumentCaptor.forClass(List.class);
+      assertEquals(capturedList.size(), 1);
+    }
+  }
+
+  private ShuffleManagerForTest createShuffleManager(
+      InputContext inputContext, int expectedNumOfPhysicalInputs)
+      throws IOException {
+    Path outDirBase = new Path(".", "outDir");
+    String[] outDirs = new String[] { outDirBase.toString() };
+    doReturn(outDirs).when(inputContext).getWorkDirs();
+    conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, 
inputContext.getWorkDirs());
+
+    DataOutputBuffer out = new DataOutputBuffer();
+    Token<JobTokenIdentifier> token = new Token<JobTokenIdentifier>(new 
JobTokenIdentifier(),
+        new JobTokenSecretManager(null));
+    token.write(out);
+    doReturn(ByteBuffer.wrap(out.getData())).when(inputContext)
+        .getServiceConsumerMetaData(
+            conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
+                TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT));
+
+    FetchedInputAllocator inputAllocator = mock(FetchedInputAllocator.class);
+    return new ShuffleManagerForTest(inputContext, conf,
+        expectedNumOfPhysicalInputs, 1024, false, -1, null, inputAllocator);
+  }
+
+  private Event createDataMovementEvent(String host, int srcIndex, int 
targetIndex) {
+    DataMovementEventPayloadProto.Builder builder = 
DataMovementEventPayloadProto.newBuilder();
+    builder.setHost(host);
+    builder.setPort(PORT);
+    builder.setPathComponent(PATH_COMPONENT);
+    Event dme = DataMovementEvent.create(srcIndex, targetIndex, 0,
+        builder.build().toByteString().asReadOnlyByteBuffer());
+    return dme;
+  }
+
+  private static class ShuffleManagerForTest extends RssShuffleManager {
+    ShuffleManagerForTest(InputContext inputContext, Configuration conf, int 
numInputs, int bufferSize,
+          boolean ifileReadAheadEnabled, int ifileReadAheadLength, 
CompressionCodec codec,
+          FetchedInputAllocator inputAllocator) throws IOException {
+      super(inputContext, conf, numInputs, bufferSize, ifileReadAheadEnabled,
+          ifileReadAheadLength, codec, inputAllocator);
+    }
+
+    @Override
+    Fetcher constructFetcherForHost(InputHost inputHost, Configuration conf) {
+      final Fetcher fetcher = spy(super.constructFetcherForHost(inputHost, 
conf));
+      final FetchResult mockFetcherResult = mock(FetchResult.class);
+      try {
+        doAnswer(new Answer<FetchResult>() {
+          @Override
+          public FetchResult answer(InvocationOnMock invocation) throws 
Throwable {
+            for (InputAttemptIdentifier input : fetcher.getSrcAttempts()) {
+              ShuffleManagerForTest.this.fetchSucceeded(
+                  fetcher.getHost(), input, new TestFetchedInput(input), 0, 0,
+                  0);
+            }
+            return mockFetcherResult;
+          }
+        }).when(fetcher).callInternal();
+      } catch (Exception e) {
+        //ignore
+      }
+      return fetcher;
+    }
+
+    public int getNumOfCompletedInputs() {
+      return completedInputSet.cardinality();
+    }
+
+    boolean isFetcherExecutorShutdown() {
+      return fetcherExecutor.isShutdown();
+    }
+  }
+
+  /**
+   * Fake input that is added to the completed input list in case an input 
does not have any data.
+   *
+   */
+  @VisibleForTesting
+  static class TestFetchedInput extends FetchedInput {
+
+    TestFetchedInput(InputAttemptIdentifier inputAttemptIdentifier) {
+      super(inputAttemptIdentifier, null);
+    }
+
+    @Override
+    public long getSize() {
+      return -1;
+    }
+
+    @Override
+    public Type getType() {
+      return Type.MEMORY;
+    }
+
+    @Override
+    public OutputStream getOutputStream() throws IOException {
+      return null;
+    }
+
+    @Override
+    public InputStream getInputStream() throws IOException {
+      return null;
+    }
+
+    @Override
+    public void commit() throws IOException {
+    }
+
+    @Override
+    public void abort() throws IOException {
+    }
+
+    @Override
+    public void free() {
+    }
+  }
+}

Reply via email to