This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new fc238005b [CELEBORN-1144] Batch OpenStream RPCs
fc238005b is described below

commit fc238005bd8482ea41612aae6aae7e8f16f918f5
Author: zky.zhoukeyong <[email protected]>
AuthorDate: Mon Mar 25 16:25:05 2024 +0800

    [CELEBORN-1144] Batch OpenStream RPCs
    
    ### What changes were proposed in this pull request?
    Batch OpenStream RPCs by Worker to avoid too many RPCs.
    
    ### Why are the changes needed?
    Without batching, the client issues one OpenStream RPC per partition location, which can overwhelm Workers with too many RPCs; batching them per Worker reduces RPC count and overhead.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Passes GA and manual tests.
    
    Closes #2362 from waitinfuture/1144.
    
    Authored-by: zky.zhoukeyong <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../flink/readclient/FlinkShuffleClientImpl.java   |   2 +-
 .../shuffle/celeborn/CelebornShuffleReader.scala   | 170 ++++++++++++----
 .../org/apache/celeborn/client/ShuffleClient.java  |  15 ++
 .../apache/celeborn/client/ShuffleClientImpl.java  |  25 ++-
 .../celeborn/client/read/CelebornInputStream.java  |  64 +++---
 .../client/read/WorkerPartitionReader.java         |  29 +--
 .../apache/celeborn/client/DummyShuffleClient.java |  17 ++
 .../celeborn/client/WithShuffleClientSuite.scala   |  26 ++-
 .../common/network/protocol/TransportMessage.java  |  28 +--
 .../common/protocol/message/StatusCode.java        |   3 +-
 common/src/main/proto/TransportMessages.proto      |  21 ++
 .../service/deploy/worker/FetchHandler.scala       | 224 ++++++++++++++-------
 .../service/deploy/cluster/ReadWriteTestBase.scala |  13 +-
 13 files changed, 448 insertions(+), 189 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index 6ced153c4..bb599c93e 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -171,7 +171,7 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
   }
 
   @Override
-  protected ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
+  public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
       throws CelebornIOException {
     ReduceFileGroups reduceFileGroups =
         reduceFileGroupsMap.computeIfAbsent(shuffleId, (id) -> new 
ReduceFileGroups());
diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index dc3dcb4bb..bb66a3cef 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -18,9 +18,12 @@
 package org.apache.spark.shuffle.celeborn
 
 import java.io.IOException
+import java.util
 import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
 import java.util.concurrent.atomic.AtomicReference
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, 
TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerInstance
@@ -33,7 +36,11 @@ import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.{CelebornIOException, 
PartitionUnRetryAbleException}
-import org.apache.celeborn.common.util.{ExceptionMaker, ThreadUtils}
+import org.apache.celeborn.common.network.client.TransportClient
+import org.apache.celeborn.common.network.protocol.TransportMessage
+import org.apache.celeborn.common.protocol.{MessageType, PartitionLocation, 
PbOpenStreamList, PbOpenStreamListResponse, PbStreamHandler}
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.{ExceptionMaker, JavaUtils, 
ThreadUtils, Utils}
 
 class CelebornShuffleReader[K, C](
     handle: CelebornShuffleHandle[K, _, C],
@@ -107,60 +114,139 @@ class CelebornShuffleReader[K, C](
       }
     }
 
-    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
-    (startPartition until endPartition).map(partitionId => {
+    val startTime = System.currentTimeMillis()
+    val fetchTimeoutMs = conf.clientFetchTimeoutMs
+    val localFetchEnabled = conf.enableReadLocalShuffleFile
+    val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
+    // startPartition is irrelevant
+    val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+    // host-port -> (TransportClient, PartitionLocation Array, 
PbOpenStreamList)
+    val workerRequestMap = new util.HashMap[
+      String,
+      (TransportClient, util.ArrayList[PartitionLocation], 
PbOpenStreamList.Builder)]()
+
+    var partCnt = 0
+
+    (startPartition until endPartition).foreach { partitionId =>
+      if (fileGroups.partitionGroups.containsKey(partitionId)) {
+        fileGroups.partitionGroups.get(partitionId).asScala.foreach { location 
=>
+          partCnt += 1
+          val hostPort = location.hostAndFetchPort
+          if (!workerRequestMap.containsKey(hostPort)) {
+            val client = shuffleClient.getDataClientFactory().createClient(
+              location.getHost,
+              location.getFetchPort)
+            val pbOpenStreamList = PbOpenStreamList.newBuilder()
+            pbOpenStreamList.setShuffleKey(shuffleKey)
+            workerRequestMap.put(
+              hostPort,
+              (client, new util.ArrayList[PartitionLocation], 
pbOpenStreamList))
+          }
+          val (_, locArr, pbOpenStreamListBuilder) = 
workerRequestMap.get(hostPort)
+
+          locArr.add(location)
+          pbOpenStreamListBuilder.addFileName(location.getFileName)
+            .addStartIndex(startMapIndex)
+            .addEndIndex(endMapIndex)
+          pbOpenStreamListBuilder.addReadLocalShuffle(localFetchEnabled)
+        }
+      }
+    }
+
+    val locationStreamHandlerMap: ConcurrentHashMap[PartitionLocation, 
PbStreamHandler] =
+      JavaUtils.newConcurrentHashMap()
+
+    val futures = workerRequestMap.values().asScala.map { entry =>
       streamCreatorPool.submit(new Runnable {
         override def run(): Unit = {
-          if (exceptionRef.get() == null) {
+          val (client, locArr, pbOpenStreamListBuilder) = entry
+          val msg = new TransportMessage(
+            MessageType.BATCH_OPEN_STREAM,
+            pbOpenStreamListBuilder.build().toByteArray)
+          val pbOpenStreamListResponse =
             try {
-              val inputStream = shuffleClient.readPartition(
-                shuffleId,
-                handle.shuffleId,
-                partitionId,
-                context.attemptNumber(),
-                startMapIndex,
-                endMapIndex,
-                if (throwsFetchFailure) exceptionMaker else null,
-                metricsCallback)
-              streams.put(partitionId, inputStream)
+              val response = client.sendRpcSync(msg.toByteBuffer, 
fetchTimeoutMs)
+              
TransportMessage.fromByteBuffer(response).getParsedPayload[PbOpenStreamListResponse]
             } catch {
-              case e: IOException =>
-                logError(s"Exception caught when readPartition $partitionId!", 
e)
-                exceptionRef.compareAndSet(null, e)
-              case e: Throwable =>
-                logError(s"Non IOException caught when readPartition 
$partitionId!", e)
-                exceptionRef.compareAndSet(null, new CelebornIOException(e))
+              case _: Exception => null
+            }
+          if (pbOpenStreamListResponse != null) {
+            0 until locArr.size() foreach { idx =>
+              val streamHandlerOpt = 
pbOpenStreamListResponse.getStreamHandlerOptList.get(idx)
+              if (streamHandlerOpt.getStatus == StatusCode.SUCCESS.getValue) {
+                locationStreamHandlerMap.put(locArr.get(idx), 
streamHandlerOpt.getStreamHandler)
+              }
             }
           }
         }
       })
-    })
+    }.toList
+    // wait for all futures to complete
+    futures.foreach(f => f.get())
+    val end = System.currentTimeMillis()
+    logInfo(s"BatchOpenStream for $partCnt cost ${end - startTime}ms")
+
+    def createInputStream(partitionId: Int): CelebornInputStream = {
+      val locations =
+        if (fileGroups.partitionGroups.containsKey(partitionId)) {
+          new util.ArrayList(fileGroups.partitionGroups.get(partitionId))
+        } else new util.ArrayList[PartitionLocation]()
+      val streamHandlers =
+        if (locations != null) {
+          val streamHandlerArr = new 
util.ArrayList[PbStreamHandler](locations.size())
+          locations.asScala.foreach { loc =>
+            streamHandlerArr.add(locationStreamHandlerMap.get(loc))
+          }
+          streamHandlerArr
+        } else null
+      if (exceptionRef.get() == null) {
+        try {
+          shuffleClient.readPartition(
+            shuffleId,
+            handle.shuffleId,
+            partitionId,
+            context.attemptNumber(),
+            startMapIndex,
+            endMapIndex,
+            if (throwsFetchFailure) exceptionMaker else null,
+            locations,
+            streamHandlers,
+            fileGroups.mapAttempts,
+            metricsCallback)
+        } catch {
+          case e: IOException =>
+            logError(s"Exception caught when readPartition $partitionId!", e)
+            exceptionRef.compareAndSet(null, e)
+            null
+          case e: Throwable =>
+            logError(s"Non IOException caught when readPartition 
$partitionId!", e)
+            exceptionRef.compareAndSet(null, new CelebornIOException(e))
+            null
+        }
+      } else null
+    }
 
     val recordIter = (startPartition until 
endPartition).iterator.map(partitionId => {
       if (handle.numMappers > 0) {
         val startFetchWait = System.nanoTime()
-        var inputStream: CelebornInputStream = streams.get(partitionId)
-        while (inputStream == null) {
-          if (exceptionRef.get() != null) {
-            exceptionRef.get() match {
-              case ce @ (_: CelebornIOException | _: 
PartitionUnRetryAbleException) =>
-                if (throwsFetchFailure &&
-                  shuffleClient.reportShuffleFetchFailure(handle.shuffleId, 
shuffleId)) {
-                  throw new FetchFailedException(
-                    null,
-                    handle.shuffleId,
-                    -1,
-                    -1,
-                    partitionId,
-                    SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + 
"/" + shuffleId,
-                    ce)
-                } else
-                  throw ce
-              case e => throw e
-            }
+        val inputStream: CelebornInputStream = createInputStream(partitionId)
+        if (exceptionRef.get() != null) {
+          exceptionRef.get() match {
+            case ce @ (_: CelebornIOException | _: 
PartitionUnRetryAbleException) =>
+              if (throwsFetchFailure &&
+                shuffleClient.reportShuffleFetchFailure(handle.shuffleId, 
shuffleId)) {
+                throw new FetchFailedException(
+                  null,
+                  handle.shuffleId,
+                  -1,
+                  -1,
+                  partitionId,
+                  SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" 
+ shuffleId,
+                  ce)
+              } else
+                throw ce
+            case e => throw e
           }
-          Thread.sleep(50)
-          inputStream = streams.get(partitionId)
         }
         metricsCallback.incReadTime(
           TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index b2e694806..aef173a62 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -18,6 +18,7 @@
 package org.apache.celeborn.client;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.LongAdder;
 
@@ -28,8 +29,11 @@ import org.slf4j.LoggerFactory;
 import org.apache.celeborn.client.read.CelebornInputStream;
 import org.apache.celeborn.client.read.MetricsCallback;
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
 import org.apache.celeborn.common.util.CelebornHadoopUtils;
 import org.apache.celeborn.common.util.ExceptionMaker;
@@ -197,6 +201,9 @@ public abstract class ShuffleClient {
   // Cleanup states of the map task
   public abstract void cleanup(int shuffleId, int mapId, int attemptId);
 
+  public abstract ShuffleClientImpl.ReduceFileGroups updateFileGroup(int 
shuffleId, int partitionId)
+      throws CelebornIOException;
+
   // Reduce side read partition which is deduplicated by 
mapperId+mapperAttemptNum+batchId, batchId
   // is a self-incrementing variable hidden in the implementation when sending 
data.
   /**
@@ -227,6 +234,9 @@ public abstract class ShuffleClient {
         startMapIndex,
         endMapIndex,
         null,
+        null,
+        null,
+        null,
         metricsCallback);
   }
 
@@ -238,6 +248,9 @@ public abstract class ShuffleClient {
       int startMapIndex,
       int endMapIndex,
       ExceptionMaker exceptionMaker,
+      ArrayList<PartitionLocation> locations,
+      ArrayList<PbStreamHandler> streamHandlers,
+      int[] mapAttempts,
       MetricsCallback metricsCallback)
       throws IOException;
 
@@ -261,4 +274,6 @@ public abstract class ShuffleClient {
    * incorrect shuffle data can be fetched in re-run tasks
    */
   public abstract boolean reportShuffleFetchFailure(int appShuffleId, int 
shuffleId);
+
+  public abstract TransportClientFactory getDataClientFactory();
 }
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index c2fe19184..0494d73e9 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -1646,7 +1646,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     }
   }
 
-  protected ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
+  public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
       throws CelebornIOException {
     if (reduceFileGroupsMap.containsKey(shuffleId)) {
       return reduceFileGroupsMap.get(shuffleId);
@@ -1679,16 +1679,28 @@ public class ShuffleClientImpl extends ShuffleClient {
       int startMapIndex,
       int endMapIndex,
       ExceptionMaker exceptionMaker,
+      ArrayList<PartitionLocation> locations,
+      ArrayList<PbStreamHandler> streamHandlers,
+      int[] mapAttempts,
       MetricsCallback metricsCallback)
       throws IOException {
     if (partitionId == Utils$.MODULE$.UNKNOWN_APP_SHUFFLE_ID()) {
       logger.warn("Shuffle data is empty for shuffle {}: 
UNKNOWN_APP_SHUFFLE_ID.", shuffleId);
       return CelebornInputStream.empty();
     }
-    ReduceFileGroups fileGroups = updateFileGroup(shuffleId, partitionId);
 
-    if (fileGroups.partitionGroups.isEmpty()
-        || !fileGroups.partitionGroups.containsKey(partitionId)) {
+    // When `mapAttempts` is not null, it's guaranteed that the code path 
comes from
+    // CelebornShuffleReader, which means `updateFileGroup` is already called 
and
+    // batch open stream has been tried
+    if (mapAttempts == null) {
+      ReduceFileGroups fileGroups = updateFileGroup(shuffleId, partitionId);
+      mapAttempts = fileGroups.mapAttempts;
+      if (fileGroups.partitionGroups.containsKey(partitionId)) {
+        locations = new ArrayList(fileGroups.partitionGroups.get(partitionId));
+      }
+    }
+
+    if (locations == null || locations.size() == 0) {
       logger.warn("Shuffle data is empty for shuffle {} partition {}.", 
shuffleId, partitionId);
       return CelebornInputStream.empty();
     } else {
@@ -1698,8 +1710,9 @@ public class ShuffleClientImpl extends ShuffleClient {
           conf,
           dataClientFactory,
           shuffleKey,
-          fileGroups.partitionGroups.get(partitionId).toArray(new 
PartitionLocation[0]),
-          fileGroups.mapAttempts,
+          locations,
+          streamHandlers,
+          mapAttempts,
           attemptNumber,
           startMapIndex,
           endMapIndex,
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java 
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index a10239bf5..4312272dd 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -25,6 +25,8 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.LongAdder;
 
+import scala.Tuple2;
+
 import com.google.common.util.concurrent.Uninterruptibles;
 import io.netty.buffer.ByteBuf;
 import org.roaringbitmap.RoaringBitmap;
@@ -37,10 +39,7 @@ import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.network.util.TransportConf;
-import org.apache.celeborn.common.protocol.CompressionCodec;
-import org.apache.celeborn.common.protocol.PartitionLocation;
-import org.apache.celeborn.common.protocol.StorageInfo;
-import org.apache.celeborn.common.protocol.TransportModuleConstants;
+import org.apache.celeborn.common.protocol.*;
 import org.apache.celeborn.common.unsafe.Platform;
 import org.apache.celeborn.common.util.ExceptionMaker;
 import org.apache.celeborn.common.util.Utils;
@@ -52,7 +51,8 @@ public abstract class CelebornInputStream extends InputStream 
{
       CelebornConf conf,
       TransportClientFactory clientFactory,
       String shuffleKey,
-      PartitionLocation[] locations,
+      ArrayList<PartitionLocation> locations,
+      ArrayList<PbStreamHandler> streamHandlers,
       int[] attempts,
       int attemptNumber,
       int startMapIndex,
@@ -65,7 +65,7 @@ public abstract class CelebornInputStream extends InputStream 
{
       ExceptionMaker exceptionMaker,
       MetricsCallback metricsCallback)
       throws IOException {
-    if (locations == null || locations.length == 0) {
+    if (locations == null || locations.size() == 0) {
       return emptyInputStream;
     } else {
       return new CelebornInputStreamImpl(
@@ -73,6 +73,7 @@ public abstract class CelebornInputStream extends InputStream 
{
           clientFactory,
           shuffleKey,
           locations,
+          streamHandlers,
           attempts,
           attemptNumber,
           startMapIndex,
@@ -124,7 +125,8 @@ public abstract class CelebornInputStream extends 
InputStream {
     private final CelebornConf conf;
     private final TransportClientFactory clientFactory;
     private final String shuffleKey;
-    private PartitionLocation[] locations;
+    private ArrayList<PartitionLocation> locations;
+    private ArrayList<PbStreamHandler> streamHandlers;
     private int[] attempts;
     private final int attemptNumber;
     private final int startMapIndex;
@@ -174,7 +176,8 @@ public abstract class CelebornInputStream extends 
InputStream {
         CelebornConf conf,
         TransportClientFactory clientFactory,
         String shuffleKey,
-        PartitionLocation[] locations,
+        ArrayList<PartitionLocation> locations,
+        ArrayList<PbStreamHandler> streamHandlers,
         int[] attempts,
         int attemptNumber,
         int startMapIndex,
@@ -190,7 +193,10 @@ public abstract class CelebornInputStream extends 
InputStream {
       this.conf = conf;
       this.clientFactory = clientFactory;
       this.shuffleKey = shuffleKey;
-      this.locations = (PartitionLocation[]) Utils.randomizeInPlace(locations, 
RAND);
+      this.locations = locations;
+      if (streamHandlers != null && streamHandlers.size() == locations.size()) 
{
+        this.streamHandlers = streamHandlers;
+      }
       this.attempts = attempts;
       this.attemptNumber = attemptNumber;
       this.startMapIndex = startMapIndex;
@@ -242,24 +248,25 @@ public abstract class CelebornInputStream extends 
InputStream {
       return true;
     }
 
-    private PartitionLocation nextReadableLocation() {
-      int locationCount = locations.length;
+    private Tuple2<PartitionLocation, PbStreamHandler> nextReadableLocation() {
+      int locationCount = locations.size();
       if (fileIndex >= locationCount) {
         return null;
       }
-      PartitionLocation currentLocation = locations[fileIndex];
+      PartitionLocation currentLocation = locations.get(fileIndex);
       while (skipLocation(startMapIndex, endMapIndex, currentLocation)) {
         skipCount.increment();
         fileIndex++;
         if (fileIndex == locationCount) {
           return null;
         }
-        currentLocation = locations[fileIndex];
+        currentLocation = locations.get(fileIndex);
       }
 
       fetchChunkRetryCnt = 0;
 
-      return currentLocation;
+      return new Tuple2(
+          currentLocation, streamHandlers == null ? null : 
streamHandlers.get(fileIndex));
     }
 
     private void moveToNextReader(boolean fetchChunk) throws IOException {
@@ -267,11 +274,11 @@ public abstract class CelebornInputStream extends 
InputStream {
         currentReader.close();
         currentReader = null;
       }
-      PartitionLocation currentLocation = nextReadableLocation();
+      Tuple2<PartitionLocation, PbStreamHandler> currentLocation = 
nextReadableLocation();
       if (currentLocation == null) {
         return;
       }
-      currentReader = createReaderWithRetry(currentLocation);
+      currentReader = createReaderWithRetry(currentLocation._1, 
currentLocation._2);
       fileIndex++;
       while (!currentReader.hasNext()) {
         currentReader.close();
@@ -280,7 +287,7 @@ public abstract class CelebornInputStream extends 
InputStream {
         if (currentLocation == null) {
           return;
         }
-        currentReader = createReaderWithRetry(currentLocation);
+        currentReader = createReaderWithRetry(currentLocation._1, 
currentLocation._2);
         fileIndex++;
       }
       if (fetchChunk) {
@@ -332,14 +339,15 @@ public abstract class CelebornInputStream extends 
InputStream {
       return connectException || rpcTimeout || fetchChunkTimeout;
     }
 
-    private PartitionReader createReaderWithRetry(PartitionLocation location) 
throws IOException {
+    private PartitionReader createReaderWithRetry(
+        PartitionLocation location, PbStreamHandler pbStreamHandler) throws 
IOException {
       Exception lastException = null;
       while (fetchChunkRetryCnt < fetchChunkMaxRetry) {
         try {
           if (isExcluded(location)) {
             throw new CelebornIOException("Fetch data from excluded worker! " 
+ location);
           }
-          return createReader(location, fetchChunkRetryCnt, 
fetchChunkMaxRetry);
+          return createReader(location, pbStreamHandler, fetchChunkRetryCnt, 
fetchChunkMaxRetry);
         } catch (Exception e) {
           lastException = e;
           excludeFailedLocation(location, e);
@@ -404,7 +412,7 @@ public abstract class CelebornInputStream extends 
InputStream {
               if (fetchChunkRetryCnt % 2 == 0) {
                 Uninterruptibles.sleepUninterruptibly(retryWaitMs, 
TimeUnit.MILLISECONDS);
               }
-              currentReader = 
createReaderWithRetry(currentReader.getLocation().getPeer());
+              currentReader = 
createReaderWithRetry(currentReader.getLocation().getPeer(), null);
             } else {
               logger.warn(
                   "Fetch chunk failed {}/{} times for location {}",
@@ -413,7 +421,7 @@ public abstract class CelebornInputStream extends 
InputStream {
                   currentReader.getLocation(),
                   e);
               Uninterruptibles.sleepUninterruptibly(retryWaitMs, 
TimeUnit.MILLISECONDS);
-              currentReader = 
createReaderWithRetry(currentReader.getLocation());
+              currentReader = 
createReaderWithRetry(currentReader.getLocation(), null);
             }
           }
         }
@@ -422,11 +430,14 @@ public abstract class CelebornInputStream extends 
InputStream {
     }
 
     private PartitionReader createReader(
-        PartitionLocation location, int fetchChunkRetryCnt, int 
fetchChunkMaxRetry)
+        PartitionLocation location,
+        PbStreamHandler pbStreamHandler,
+        int fetchChunkRetryCnt,
+        int fetchChunkMaxRetry)
         throws IOException, InterruptedException {
       if (!location.hasPeer()) {
         logger.debug("Partition {} has only one partition replica.", location);
-      } else if (attemptNumber % 2 == 1) {
+      } else if (pbStreamHandler == null && attemptNumber % 2 == 1) {
         location = location.getPeer();
         logger.debug("Read peer {} for attempt {}.", location, attemptNumber);
       }
@@ -446,6 +457,7 @@ public abstract class CelebornInputStream extends 
InputStream {
                 conf,
                 shuffleKey,
                 location,
+                pbStreamHandler,
                 clientFactory,
                 startMapIndex,
                 endMapIndex,
@@ -513,7 +525,7 @@ public abstract class CelebornInputStream extends 
InputStream {
     @Override
     public synchronized void close() {
       if (!closed) {
-        int locationsCount = locations.length;
+        int locationsCount = locations.size();
         logger.debug(
             "AppShuffleId {}, shuffleId {}, partitionId {}, total location 
count {}, read {}, skip {}",
             appShuffleId,
@@ -556,7 +568,7 @@ public abstract class CelebornInputStream extends 
InputStream {
       if (currentReader.hasNext()) {
         currentChunk = getNextChunk();
         return true;
-      } else if (fileIndex < locations.length) {
+      } else if (fileIndex < locations.size()) {
         moveToNextReader(true);
         return currentReader != null;
       }
@@ -668,7 +680,7 @@ public abstract class CelebornInputStream extends 
InputStream {
 
     @Override
     public int totalPartitionsToRead() {
-      return locations.length;
+      return locations.size();
     }
 
     @Override
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
 
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
index 0d474ee72..3158aa12f 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
@@ -74,6 +74,7 @@ public class WorkerPartitionReader implements PartitionReader 
{
       CelebornConf conf,
       String shuffleKey,
       PartitionLocation location,
+      PbStreamHandler pbStreamHandler,
       TransportClientFactory clientFactory,
       int startMapIndex,
       int endMapIndex,
@@ -116,18 +117,22 @@ public class WorkerPartitionReader implements 
PartitionReader {
       throw ie;
     }
 
-    TransportMessage openStreamMsg =
-        new TransportMessage(
-            MessageType.OPEN_STREAM,
-            PbOpenStream.newBuilder()
-                .setShuffleKey(shuffleKey)
-                .setFileName(location.getFileName())
-                .setStartIndex(startMapIndex)
-                .setEndIndex(endMapIndex)
-                .build()
-                .toByteArray());
-    ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), 
fetchTimeoutMs);
-    streamHandler = 
TransportMessage.fromByteBuffer(response).getParsedPayload();
+    if (pbStreamHandler == null) {
+      TransportMessage openStreamMsg =
+          new TransportMessage(
+              MessageType.OPEN_STREAM,
+              PbOpenStream.newBuilder()
+                  .setShuffleKey(shuffleKey)
+                  .setFileName(location.getFileName())
+                  .setStartIndex(startMapIndex)
+                  .setEndIndex(endMapIndex)
+                  .build()
+                  .toByteArray());
+      ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), 
fetchTimeoutMs);
+      this.streamHandler = 
TransportMessage.fromByteBuffer(response).getParsedPayload();
+    } else {
+      this.streamHandler = pbStreamHandler;
+    }
 
     this.location = location;
     this.clientFactory = clientFactory;
diff --git 
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java 
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 47642019a..cae634fa7 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -36,7 +36,10 @@ import org.slf4j.LoggerFactory;
 import org.apache.celeborn.client.read.CelebornInputStream;
 import org.apache.celeborn.client.read.MetricsCallback;
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornIOException;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
 import org.apache.celeborn.common.util.ExceptionMaker;
 import org.apache.celeborn.common.util.JavaUtils;
@@ -115,6 +118,12 @@ public class DummyShuffleClient extends ShuffleClient {
   @Override
   public void cleanup(int shuffleId, int mapId, int attemptId) {}
 
+  @Override
+  public ShuffleClientImpl.ReduceFileGroups updateFileGroup(int shuffleId, int 
partitionId)
+      throws CelebornIOException {
+    return null;
+  }
+
   @Override
   public CelebornInputStream readPartition(
       int shuffleId,
@@ -124,6 +133,9 @@ public class DummyShuffleClient extends ShuffleClient {
       int startMapIndex,
       int endMapIndex,
       ExceptionMaker exceptionMaker,
+      ArrayList<PartitionLocation> locations,
+      ArrayList<PbStreamHandler> streamHandlers,
+      int[] mapAttempts,
       MetricsCallback metricsCallback)
       throws IOException {
     return null;
@@ -170,6 +182,11 @@ public class DummyShuffleClient extends ShuffleClient {
     return true;
   }
 
+  @Override
+  public TransportClientFactory getDataClientFactory() {
+    return null;
+  }
+
   public void initReducePartitionMap(int shuffleId, int numPartitions, int 
workerNum) {
     ConcurrentHashMap<Integer, PartitionLocation> map = 
JavaUtils.newConcurrentHashMap();
     String host = "host";
diff --git 
a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala 
b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
index 88247d9b6..d5631d5b0 100644
--- 
a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
+++ 
b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
@@ -152,11 +152,33 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
     }
 
     // reduce normal empty CelebornInputStream
-    var stream = shuffleClient.readPartition(shuffleId, 1, 1, 0, 
Integer.MAX_VALUE, metricsCallback)
+    var stream = shuffleClient.readPartition(
+      shuffleId,
+      shuffleId,
+      1,
+      1,
+      0,
+      Integer.MAX_VALUE,
+      null,
+      null,
+      null,
+      null,
+      metricsCallback)
     Assert.assertEquals(stream.read(), -1)
 
     // reduce normal null partition for CelebornInputStream
-    stream = shuffleClient.readPartition(shuffleId, 3, 1, 0, 
Integer.MAX_VALUE, metricsCallback)
+    stream = shuffleClient.readPartition(
+      shuffleId,
+      shuffleId,
+      3,
+      1,
+      0,
+      Integer.MAX_VALUE,
+      null,
+      null,
+      null,
+      null,
+      metricsCallback)
     Assert.assertEquals(stream.read(), -1)
   }
 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 392edce85..1d684b217 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -28,29 +28,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.exception.CelebornIOException;
-import org.apache.celeborn.common.protocol.MessageType;
-import org.apache.celeborn.common.protocol.PbApplicationMeta;
-import org.apache.celeborn.common.protocol.PbApplicationMetaRequest;
-import org.apache.celeborn.common.protocol.PbAuthenticationInitiationRequest;
-import org.apache.celeborn.common.protocol.PbAuthenticationInitiationResponse;
-import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
-import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
-import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
-import org.apache.celeborn.common.protocol.PbGetShuffleId;
-import org.apache.celeborn.common.protocol.PbGetShuffleIdResponse;
-import org.apache.celeborn.common.protocol.PbOpenStream;
-import org.apache.celeborn.common.protocol.PbPushDataHandShake;
-import org.apache.celeborn.common.protocol.PbReadAddCredit;
-import org.apache.celeborn.common.protocol.PbRegionFinish;
-import org.apache.celeborn.common.protocol.PbRegionStart;
-import org.apache.celeborn.common.protocol.PbRegisterApplicationRequest;
-import org.apache.celeborn.common.protocol.PbRegisterApplicationResponse;
-import org.apache.celeborn.common.protocol.PbReportShuffleFetchFailure;
-import org.apache.celeborn.common.protocol.PbReportShuffleFetchFailureResponse;
-import org.apache.celeborn.common.protocol.PbSaslRequest;
-import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
-import org.apache.celeborn.common.protocol.PbStreamHandler;
-import org.apache.celeborn.common.protocol.PbTransportableError;
+import org.apache.celeborn.common.protocol.*;
 
 public class TransportMessage implements Serializable {
   private static final long serialVersionUID = -3259000920699629773L;
@@ -123,6 +101,10 @@ public class TransportMessage implements Serializable {
         return (T) PbApplicationMeta.parseFrom(payload);
       case APPLICATION_META_REQUEST_VALUE:
         return (T) PbApplicationMetaRequest.parseFrom(payload);
+      case BATCH_OPEN_STREAM_VALUE:
+        return (T) PbOpenStreamList.parseFrom(payload);
+      case BATCH_OPEN_STREAM_RESPONSE_VALUE:
+        return (T) PbOpenStreamListResponse.parseFrom(payload);
       default:
         logger.error("Unexpected type {}", type);
     }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index 3d5d6a790..0ebfad65a 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -81,7 +81,8 @@ public enum StatusCode {
   REVIVE_INITIALIZED(47),
   DESTROY_SLOTS_MOCK_FAILURE(48),
   COMMIT_FILES_MOCK_FAILURE(49),
-  PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA(50);
+  PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA(50),
+  OPEN_STREAM_FAILED(51);
 
   private final byte value;
 
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 2d24431e2..7f190a3bb 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -100,6 +100,8 @@ enum MessageType {
   WORKER_EVENT_RESPONSE = 77;
   APPLICATION_META = 78;
   APPLICATION_META_REQUEST = 79;
+  BATCH_OPEN_STREAM = 80;
+  BATCH_OPEN_STREAM_RESPONSE = 81;
 }
 
 enum StreamType {
@@ -635,6 +637,25 @@ message PbStreamHandler {
   string fullPath = 4;
 }
 
+message PbOpenStreamList {
+  string shuffleKey = 1;
+  repeated string fileName = 2;
+  repeated int32 startIndex = 3;
+  repeated int32 endIndex = 4;
+  repeated int32 initialCredit = 5;
+  repeated bool readLocalShuffle = 6;
+}
+
+message PbStreamHandlerOpt {
+  int32 status = 1;
+  PbStreamHandler streamHandler = 2;
+  string errorMsg = 3;
+}
+
+message PbOpenStreamListResponse {
+  repeated PbStreamHandlerOpt streamHandlerOpt = 2;
+}
+
 message PbPushDataHandShake {
   PbPartitionLocation.Mode mode = 1;
   string shuffleKey = 2;
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index ec1ac8b03..114f50f2b 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -20,9 +20,12 @@ package org.apache.celeborn.service.deploy.worker
 import java.io.{FileNotFoundException, IOException}
 import java.nio.charset.StandardCharsets
 import java.util
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.function.Consumer
 
+import scala.collection.JavaConverters._
+
 import com.google.common.base.Throwables
 import com.google.protobuf.GeneratedMessageV3
 import io.netty.util.concurrent.{Future, GenericFutureListener}
@@ -37,8 +40,9 @@ import 
org.apache.celeborn.common.network.client.{RpcResponseCallback, Transport
 import org.apache.celeborn.common.network.protocol._
 import org.apache.celeborn.common.network.server.BaseMessageHandler
 import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
-import org.apache.celeborn.common.protocol.{MessageType, PbBufferStreamEnd, 
PbChunkFetchRequest, PbOpenStream, PbReadAddCredit, PbStreamHandler, StreamType}
-import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
+import org.apache.celeborn.common.protocol.{MessageType, PbBufferStreamEnd, 
PbChunkFetchRequest, PbOpenStream, PbOpenStreamList, PbOpenStreamListResponse, 
PbReadAddCredit, PbStreamHandler, PbStreamHandlerOpt, StreamType}
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.{ExceptionUtils, ThreadUtils, Utils}
 import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, 
CreditStreamManager, PartitionFilesSorter, StorageManager}
 
 class FetchHandler(
@@ -136,6 +140,33 @@ class FetchHandler(
           isLegacy = false,
           openStream.getReadLocalShuffle,
           callback)
+      case openStreamList: PbOpenStreamList =>
+        val shuffleKey = openStreamList.getShuffleKey()
+        val files = openStreamList.getFileNameList
+        val startIndices = openStreamList.getStartIndexList
+        val endIndices = openStreamList.getEndIndexList
+        val readLocalFlags = openStreamList.getReadLocalShuffleList
+        val pbOpenStreamListResponse = PbOpenStreamListResponse.newBuilder()
+
+        0 until files.size() foreach { idx =>
+          val pbStreamHandlerOpt = handleReduceOpenStreamInternal(
+            client,
+            shuffleKey,
+            files.get(idx),
+            startIndices.get(idx),
+            endIndices.get(idx),
+            readLocalFlags.get(idx))
+          if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) {
+            workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT)
+          }
+          pbOpenStreamListResponse.addStreamHandlerOpt(pbStreamHandlerOpt)
+        }
+
+        client.getChannel.writeAndFlush(new RpcResponse(
+          rpcRequest.requestId,
+          new NioManagedBuffer(new TransportMessage(
+            MessageType.BATCH_OPEN_STREAM_RESPONSE,
+            pbOpenStreamListResponse.build().toByteArray).toByteBuffer)))
       case bufferStreamEnd: PbBufferStreamEnd =>
         handleEndStreamFromClient(
           client,
@@ -197,6 +228,86 @@ class FetchHandler(
 
   }
 
+  private def handleReduceOpenStreamInternal(
+      client: TransportClient,
+      shuffleKey: String,
+      fileName: String,
+      startIndex: Int,
+      endIndex: Int,
+      readLocalShuffle: Boolean = false): PbStreamHandlerOpt = {
+    try {
+      logDebug(s"Received open stream request $shuffleKey $fileName 
$startIndex " +
+        s"$endIndex get file name $fileName from client channel " +
+        s"${NettyUtils.getRemoteAddress(client.getChannel)}")
+
+      var fileInfo = getRawDiskFileInfo(shuffleKey, fileName)
+      val streamId = chunkStreamManager.nextStreamId()
+      // we must get sorted fileInfo for the following cases.
+      // 1. when the current request is a non-range openStream, but the 
original unsorted file
+      //    has been deleted by another range's openStream request.
+      // 2. when the current request is a range openStream request.
+      if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue && 
!fileInfo.addStream(
+          streamId))) {
+        fileInfo = partitionsSorter.getSortedFileInfo(
+          shuffleKey,
+          fileName,
+          fileInfo,
+          startIndex,
+          endIndex)
+      }
+      val meta = fileInfo.getFileMeta.asInstanceOf[ReduceFileMeta]
+      val streamHandler =
+        if (readLocalShuffle) {
+          chunkStreamManager.registerStream(
+            streamId,
+            shuffleKey,
+            fileName)
+          makeStreamHandler(
+            streamId,
+            meta.getNumChunks,
+            meta.getChunkOffsets,
+            fileInfo.getFilePath)
+        } else if (fileInfo.isHdfs) {
+          chunkStreamManager.registerStream(
+            streamId,
+            shuffleKey,
+            fileName)
+          makeStreamHandler(streamId, numChunks = 0)
+        } else {
+          chunkStreamManager.registerStream(
+            streamId,
+            shuffleKey,
+            new FileManagedBuffers(fileInfo, transportConf),
+            fileName,
+            storageManager.getFetchTimeMetric(fileInfo.getFile))
+          if (meta.getNumChunks == 0)
+            logDebug(s"StreamId $streamId, fileName $fileName, mapRange " +
+              s"[$startIndex-$endIndex] is empty. Received from client channel 
" +
+              s"${NettyUtils.getRemoteAddress(client.getChannel)}")
+          else logDebug(
+            s"StreamId $streamId, fileName $fileName, numChunks 
${meta.getNumChunks}, " +
+              s"mapRange [$startIndex-$endIndex]. Received from client channel 
" +
+              s"${NettyUtils.getRemoteAddress(client.getChannel)}")
+          makeStreamHandler(
+            streamId,
+            meta.getNumChunks)
+        }
+      workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT)
+      PbStreamHandlerOpt.newBuilder().setStreamHandler(streamHandler)
+        .setStatus(StatusCode.SUCCESS.getValue)
+        .build()
+    } catch {
+      case e: IOException =>
+        val msg =
+          s"Read file: $fileName with shuffleKey: $shuffleKey error from 
${NettyUtils.getRemoteAddress(
+            client.getChannel)}, Exception: ${e.getMessage}"
+        
PbStreamHandlerOpt.newBuilder().setStatus(StatusCode.OPEN_STREAM_FAILED.getValue)
+          .setErrorMsg(msg).build()
+    } finally {
+      workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
+    }
+  }
+
   private def handleOpenStreamInternal(
       client: TransportClient,
       shuffleKey: String,
@@ -211,74 +322,31 @@ class FetchHandler(
     workerSource.recordAppActiveConnection(client, shuffleKey)
     workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
     try {
-      var fileInfo = getRawDiskFileInfo(shuffleKey, fileName)
+      val fileInfo = getRawDiskFileInfo(shuffleKey, fileName)
       fileInfo.getFileMeta match {
         case _: ReduceFileMeta =>
-          logDebug(s"Received open stream request $shuffleKey $fileName 
$startIndex " +
-            s"$endIndex get file name $fileName from client channel " +
-            s"${NettyUtils.getRemoteAddress(client.getChannel)}")
-
-          val streamId = chunkStreamManager.nextStreamId()
-          // we must get sorted fileInfo for the following cases.
-          // 1. when the current request is a non-range openStream, but the 
original unsorted file
-          //    has been deleted by another range's openStream request.
-          // 2. when the current request is a range openStream request.
-          if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue && 
!fileInfo.addStream(
-              streamId))) {
-            fileInfo = partitionsSorter.getSortedFileInfo(
-              shuffleKey,
-              fileName,
-              fileInfo,
-              startIndex,
-              endIndex)
-          }
-          val meta = fileInfo.getFileMeta.asInstanceOf[ReduceFileMeta]
-          if (readLocalShuffle) {
-            chunkStreamManager.registerStream(
-              streamId,
-              shuffleKey,
-              fileName)
-            replyStreamHandler(
+          val pbStreamHandlerOpt =
+            handleReduceOpenStreamInternal(
               client,
-              rpcRequestId,
-              streamId,
-              meta.getNumChunks,
-              isLegacy,
-              meta.getChunkOffsets,
-              fileInfo.getFilePath)
-          } else if (fileInfo.isHdfs) {
-            chunkStreamManager.registerStream(
-              streamId,
-              shuffleKey,
-              fileName)
-            replyStreamHandler(client, rpcRequestId, streamId, numChunks = 0, 
isLegacy)
-          } else {
-            chunkStreamManager.registerStream(
-              streamId,
               shuffleKey,
-              new FileManagedBuffers(fileInfo, transportConf),
               fileName,
-              storageManager.getFetchTimeMetric(fileInfo.getFile))
-            if (meta.getNumChunks == 0)
-              logDebug(s"StreamId $streamId, fileName $fileName, mapRange " +
-                s"[$startIndex-$endIndex] is empty. Received from client 
channel " +
-                s"${NettyUtils.getRemoteAddress(client.getChannel)}")
-            else logDebug(
-              s"StreamId $streamId, fileName $fileName, numChunks 
${meta.getNumChunks}, " +
-                s"mapRange [$startIndex-$endIndex]. Received from client 
channel " +
-                s"${NettyUtils.getRemoteAddress(client.getChannel)}")
-            replyStreamHandler(
-              client,
-              rpcRequestId,
-              streamId,
-              meta.getNumChunks,
-              isLegacy)
+              startIndex,
+              endIndex,
+              readLocalShuffle)
+
+          if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) {
+            throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg)
           }
+          replyStreamHandler(client, rpcRequestId, 
pbStreamHandlerOpt.getStreamHandler, isLegacy)
         case _: MapFileMeta =>
           val creditStreamHandler =
             new Consumer[java.lang.Long] {
               override def accept(streamId: java.lang.Long): Unit = {
-                replyStreamHandler(client, rpcRequestId, streamId, 0, isLegacy)
+                val pbStreamHandler = PbStreamHandler.newBuilder
+                  .setStreamId(streamId)
+                  .setNumChunks(0)
+                  .build()
+                replyStreamHandler(client, rpcRequestId, pbStreamHandler, 
isLegacy)
               }
             }
 
@@ -301,28 +369,34 @@ class FetchHandler(
     }
   }
 
-  private def replyStreamHandler(
-      client: TransportClient,
-      requestId: Long,
+  private def makeStreamHandler(
       streamId: Long,
       numChunks: Int,
-      isLegacy: Boolean,
       offsets: util.List[java.lang.Long] = null,
-      filepath: String = ""): Unit = {
+      filepath: String = ""): PbStreamHandler = {
+    val pbStreamHandlerBuilder = 
PbStreamHandler.newBuilder.setStreamId(streamId).setNumChunks(
+      numChunks)
+    if (offsets != null) {
+      pbStreamHandlerBuilder.addAllChunkOffsets(offsets)
+    }
+    if (filepath.nonEmpty) {
+      pbStreamHandlerBuilder.setFullPath(filepath)
+    }
+    pbStreamHandlerBuilder.build()
+  }
+
+  private def replyStreamHandler(
+      client: TransportClient,
+      requestId: Long,
+      pbStreamHandler: PbStreamHandler,
+      isLegacy: Boolean): Unit = {
     if (isLegacy) {
       client.getChannel.writeAndFlush(new RpcResponse(
         requestId,
-        new NioManagedBuffer(new StreamHandle(streamId, 
numChunks).toByteBuffer)))
+        new NioManagedBuffer(new StreamHandle(
+          pbStreamHandler.getStreamId,
+          pbStreamHandler.getNumChunks).toByteBuffer)))
     } else {
-      val pbStreamHandlerBuilder = 
PbStreamHandler.newBuilder.setStreamId(streamId).setNumChunks(
-        numChunks)
-      if (offsets != null) {
-        pbStreamHandlerBuilder.addAllChunkOffsets(offsets)
-      }
-      if (filepath.nonEmpty) {
-        pbStreamHandlerBuilder.setFullPath(filepath)
-      }
-      val pbStreamHandler = pbStreamHandlerBuilder.build()
       client.getChannel.writeAndFlush(new RpcResponse(
         requestId,
         new NioManagedBuffer(new TransportMessage(
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
index 3009a6445..d62a78515 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
@@ -104,7 +104,18 @@ trait ReadWriteTestBase extends AnyFunSuite
       override def incBytesRead(bytesWritten: Long): Unit = {}
       override def incReadTime(time: Long): Unit = {}
     }
-    val inputStream = shuffleClient.readPartition(1, 0, 0, 0, 
Integer.MAX_VALUE, metricsCallback)
+    val inputStream = shuffleClient.readPartition(
+      1,
+      1,
+      0,
+      0,
+      0,
+      Integer.MAX_VALUE,
+      null,
+      null,
+      null,
+      null,
+      metricsCallback)
     val outputStream = new ByteArrayOutputStream()
 
     var b = inputStream.read()

Reply via email to