This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new fc238005b [CELEBORN-1144] Batch OpenStream RPCs
fc238005b is described below
commit fc238005bd8482ea41612aae6aae7e8f16f918f5
Author: zky.zhoukeyong <[email protected]>
AuthorDate: Mon Mar 25 16:25:05 2024 +0800
[CELEBORN-1144] Batch OpenStream RPCs
### What changes were proposed in this pull request?
Batch OpenStream RPCs by Worker to avoid too many RPCs.
### Why are the changes needed?
Batching OpenStream RPCs by Worker reduces the number of fetch-related RPCs a reader sends, avoiding RPC storms when many partition streams are opened at once.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Passes GA and manual tests.
Closes #2362 from waitinfuture/1144.
Authored-by: zky.zhoukeyong <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../flink/readclient/FlinkShuffleClientImpl.java | 2 +-
.../shuffle/celeborn/CelebornShuffleReader.scala | 170 ++++++++++++----
.../org/apache/celeborn/client/ShuffleClient.java | 15 ++
.../apache/celeborn/client/ShuffleClientImpl.java | 25 ++-
.../celeborn/client/read/CelebornInputStream.java | 64 +++---
.../client/read/WorkerPartitionReader.java | 29 +--
.../apache/celeborn/client/DummyShuffleClient.java | 17 ++
.../celeborn/client/WithShuffleClientSuite.scala | 26 ++-
.../common/network/protocol/TransportMessage.java | 28 +--
.../common/protocol/message/StatusCode.java | 3 +-
common/src/main/proto/TransportMessages.proto | 21 ++
.../service/deploy/worker/FetchHandler.scala | 224 ++++++++++++++-------
.../service/deploy/cluster/ReadWriteTestBase.scala | 13 +-
13 files changed, 448 insertions(+), 189 deletions(-)
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index 6ced153c4..bb599c93e 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -171,7 +171,7 @@ public class FlinkShuffleClientImpl extends
ShuffleClientImpl {
}
@Override
- protected ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
+ public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
throws CelebornIOException {
ReduceFileGroups reduceFileGroups =
reduceFileGroupsMap.computeIfAbsent(shuffleId, (id) -> new
ReduceFileGroups());
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index dc3dcb4bb..bb66a3cef 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -18,9 +18,12 @@
package org.apache.spark.shuffle.celeborn
import java.io.IOException
+import java.util
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
+import scala.collection.JavaConverters._
+
import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency,
TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerInstance
@@ -33,7 +36,11 @@ import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornIOException,
PartitionUnRetryAbleException}
-import org.apache.celeborn.common.util.{ExceptionMaker, ThreadUtils}
+import org.apache.celeborn.common.network.client.TransportClient
+import org.apache.celeborn.common.network.protocol.TransportMessage
+import org.apache.celeborn.common.protocol.{MessageType, PartitionLocation,
PbOpenStreamList, PbOpenStreamListResponse, PbStreamHandler}
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.{ExceptionMaker, JavaUtils,
ThreadUtils, Utils}
class CelebornShuffleReader[K, C](
handle: CelebornShuffleHandle[K, _, C],
@@ -107,60 +114,139 @@ class CelebornShuffleReader[K, C](
}
}
- val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
- (startPartition until endPartition).map(partitionId => {
+ val startTime = System.currentTimeMillis()
+ val fetchTimeoutMs = conf.clientFetchTimeoutMs
+ val localFetchEnabled = conf.enableReadLocalShuffleFile
+ val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
+ // startPartition is irrelevant
+ val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+ // host-port -> (TransportClient, PartitionLocation Array,
PbOpenStreamList)
+ val workerRequestMap = new util.HashMap[
+ String,
+ (TransportClient, util.ArrayList[PartitionLocation],
PbOpenStreamList.Builder)]()
+
+ var partCnt = 0
+
+ (startPartition until endPartition).foreach { partitionId =>
+ if (fileGroups.partitionGroups.containsKey(partitionId)) {
+ fileGroups.partitionGroups.get(partitionId).asScala.foreach { location
=>
+ partCnt += 1
+ val hostPort = location.hostAndFetchPort
+ if (!workerRequestMap.containsKey(hostPort)) {
+ val client = shuffleClient.getDataClientFactory().createClient(
+ location.getHost,
+ location.getFetchPort)
+ val pbOpenStreamList = PbOpenStreamList.newBuilder()
+ pbOpenStreamList.setShuffleKey(shuffleKey)
+ workerRequestMap.put(
+ hostPort,
+ (client, new util.ArrayList[PartitionLocation],
pbOpenStreamList))
+ }
+ val (_, locArr, pbOpenStreamListBuilder) =
workerRequestMap.get(hostPort)
+
+ locArr.add(location)
+ pbOpenStreamListBuilder.addFileName(location.getFileName)
+ .addStartIndex(startMapIndex)
+ .addEndIndex(endMapIndex)
+ pbOpenStreamListBuilder.addReadLocalShuffle(localFetchEnabled)
+ }
+ }
+ }
+
+ val locationStreamHandlerMap: ConcurrentHashMap[PartitionLocation,
PbStreamHandler] =
+ JavaUtils.newConcurrentHashMap()
+
+ val futures = workerRequestMap.values().asScala.map { entry =>
streamCreatorPool.submit(new Runnable {
override def run(): Unit = {
- if (exceptionRef.get() == null) {
+ val (client, locArr, pbOpenStreamListBuilder) = entry
+ val msg = new TransportMessage(
+ MessageType.BATCH_OPEN_STREAM,
+ pbOpenStreamListBuilder.build().toByteArray)
+ val pbOpenStreamListResponse =
try {
- val inputStream = shuffleClient.readPartition(
- shuffleId,
- handle.shuffleId,
- partitionId,
- context.attemptNumber(),
- startMapIndex,
- endMapIndex,
- if (throwsFetchFailure) exceptionMaker else null,
- metricsCallback)
- streams.put(partitionId, inputStream)
+ val response = client.sendRpcSync(msg.toByteBuffer,
fetchTimeoutMs)
+
TransportMessage.fromByteBuffer(response).getParsedPayload[PbOpenStreamListResponse]
} catch {
- case e: IOException =>
- logError(s"Exception caught when readPartition $partitionId!",
e)
- exceptionRef.compareAndSet(null, e)
- case e: Throwable =>
- logError(s"Non IOException caught when readPartition
$partitionId!", e)
- exceptionRef.compareAndSet(null, new CelebornIOException(e))
+ case _: Exception => null
+ }
+ if (pbOpenStreamListResponse != null) {
+ 0 until locArr.size() foreach { idx =>
+ val streamHandlerOpt =
pbOpenStreamListResponse.getStreamHandlerOptList.get(idx)
+ if (streamHandlerOpt.getStatus == StatusCode.SUCCESS.getValue) {
+ locationStreamHandlerMap.put(locArr.get(idx),
streamHandlerOpt.getStreamHandler)
+ }
}
}
}
})
- })
+ }.toList
+ // wait for all futures to complete
+ futures.foreach(f => f.get())
+ val end = System.currentTimeMillis()
+ logInfo(s"BatchOpenStream for $partCnt cost ${end - startTime}ms")
+
+ def createInputStream(partitionId: Int): CelebornInputStream = {
+ val locations =
+ if (fileGroups.partitionGroups.containsKey(partitionId)) {
+ new util.ArrayList(fileGroups.partitionGroups.get(partitionId))
+ } else new util.ArrayList[PartitionLocation]()
+ val streamHandlers =
+ if (locations != null) {
+ val streamHandlerArr = new
util.ArrayList[PbStreamHandler](locations.size())
+ locations.asScala.foreach { loc =>
+ streamHandlerArr.add(locationStreamHandlerMap.get(loc))
+ }
+ streamHandlerArr
+ } else null
+ if (exceptionRef.get() == null) {
+ try {
+ shuffleClient.readPartition(
+ shuffleId,
+ handle.shuffleId,
+ partitionId,
+ context.attemptNumber(),
+ startMapIndex,
+ endMapIndex,
+ if (throwsFetchFailure) exceptionMaker else null,
+ locations,
+ streamHandlers,
+ fileGroups.mapAttempts,
+ metricsCallback)
+ } catch {
+ case e: IOException =>
+ logError(s"Exception caught when readPartition $partitionId!", e)
+ exceptionRef.compareAndSet(null, e)
+ null
+ case e: Throwable =>
+ logError(s"Non IOException caught when readPartition
$partitionId!", e)
+ exceptionRef.compareAndSet(null, new CelebornIOException(e))
+ null
+ }
+ } else null
+ }
val recordIter = (startPartition until
endPartition).iterator.map(partitionId => {
if (handle.numMappers > 0) {
val startFetchWait = System.nanoTime()
- var inputStream: CelebornInputStream = streams.get(partitionId)
- while (inputStream == null) {
- if (exceptionRef.get() != null) {
- exceptionRef.get() match {
- case ce @ (_: CelebornIOException | _:
PartitionUnRetryAbleException) =>
- if (throwsFetchFailure &&
- shuffleClient.reportShuffleFetchFailure(handle.shuffleId,
shuffleId)) {
- throw new FetchFailedException(
- null,
- handle.shuffleId,
- -1,
- -1,
- partitionId,
- SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId +
"/" + shuffleId,
- ce)
- } else
- throw ce
- case e => throw e
- }
+ val inputStream: CelebornInputStream = createInputStream(partitionId)
+ if (exceptionRef.get() != null) {
+ exceptionRef.get() match {
+ case ce @ (_: CelebornIOException | _:
PartitionUnRetryAbleException) =>
+ if (throwsFetchFailure &&
+ shuffleClient.reportShuffleFetchFailure(handle.shuffleId,
shuffleId)) {
+ throw new FetchFailedException(
+ null,
+ handle.shuffleId,
+ -1,
+ -1,
+ partitionId,
+ SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/"
+ shuffleId,
+ ce)
+ } else
+ throw ce
+ case e => throw e
}
- Thread.sleep(50)
- inputStream = streams.get(partitionId)
}
metricsCallback.incReadTime(
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index b2e694806..aef173a62 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -18,6 +18,7 @@
package org.apache.celeborn.client;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
@@ -28,8 +29,11 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.CelebornHadoopUtils;
import org.apache.celeborn.common.util.ExceptionMaker;
@@ -197,6 +201,9 @@ public abstract class ShuffleClient {
// Cleanup states of the map task
public abstract void cleanup(int shuffleId, int mapId, int attemptId);
+ public abstract ShuffleClientImpl.ReduceFileGroups updateFileGroup(int
shuffleId, int partitionId)
+ throws CelebornIOException;
+
// Reduce side read partition which is deduplicated by
mapperId+mapperAttemptNum+batchId, batchId
// is a self-incrementing variable hidden in the implementation when sending
data.
/**
@@ -227,6 +234,9 @@ public abstract class ShuffleClient {
startMapIndex,
endMapIndex,
null,
+ null,
+ null,
+ null,
metricsCallback);
}
@@ -238,6 +248,9 @@ public abstract class ShuffleClient {
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
+ ArrayList<PartitionLocation> locations,
+ ArrayList<PbStreamHandler> streamHandlers,
+ int[] mapAttempts,
MetricsCallback metricsCallback)
throws IOException;
@@ -261,4 +274,6 @@ public abstract class ShuffleClient {
* incorrect shuffle data can be fetched in re-run tasks
*/
public abstract boolean reportShuffleFetchFailure(int appShuffleId, int
shuffleId);
+
+ public abstract TransportClientFactory getDataClientFactory();
}
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index c2fe19184..0494d73e9 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -1646,7 +1646,7 @@ public class ShuffleClientImpl extends ShuffleClient {
}
}
- protected ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
+ public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
throws CelebornIOException {
if (reduceFileGroupsMap.containsKey(shuffleId)) {
return reduceFileGroupsMap.get(shuffleId);
@@ -1679,16 +1679,28 @@ public class ShuffleClientImpl extends ShuffleClient {
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
+ ArrayList<PartitionLocation> locations,
+ ArrayList<PbStreamHandler> streamHandlers,
+ int[] mapAttempts,
MetricsCallback metricsCallback)
throws IOException {
if (partitionId == Utils$.MODULE$.UNKNOWN_APP_SHUFFLE_ID()) {
logger.warn("Shuffle data is empty for shuffle {}:
UNKNOWN_APP_SHUFFLE_ID.", shuffleId);
return CelebornInputStream.empty();
}
- ReduceFileGroups fileGroups = updateFileGroup(shuffleId, partitionId);
- if (fileGroups.partitionGroups.isEmpty()
- || !fileGroups.partitionGroups.containsKey(partitionId)) {
+ // When `mapAttempts` is not null, it's guaranteed that the code path
comes from
+ // CelebornShuffleReader, which means `updateFileGroup` is already called
and
+ // batch open stream has been tried
+ if (mapAttempts == null) {
+ ReduceFileGroups fileGroups = updateFileGroup(shuffleId, partitionId);
+ mapAttempts = fileGroups.mapAttempts;
+ if (fileGroups.partitionGroups.containsKey(partitionId)) {
+ locations = new ArrayList(fileGroups.partitionGroups.get(partitionId));
+ }
+ }
+
+ if (locations == null || locations.size() == 0) {
logger.warn("Shuffle data is empty for shuffle {} partition {}.",
shuffleId, partitionId);
return CelebornInputStream.empty();
} else {
@@ -1698,8 +1710,9 @@ public class ShuffleClientImpl extends ShuffleClient {
conf,
dataClientFactory,
shuffleKey,
- fileGroups.partitionGroups.get(partitionId).toArray(new
PartitionLocation[0]),
- fileGroups.mapAttempts,
+ locations,
+ streamHandlers,
+ mapAttempts,
attemptNumber,
startMapIndex,
endMapIndex,
diff --git
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index a10239bf5..4312272dd 100644
---
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -25,6 +25,8 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.LongAdder;
+import scala.Tuple2;
+
import com.google.common.util.concurrent.Uninterruptibles;
import io.netty.buffer.ByteBuf;
import org.roaringbitmap.RoaringBitmap;
@@ -37,10 +39,7 @@ import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.util.TransportConf;
-import org.apache.celeborn.common.protocol.CompressionCodec;
-import org.apache.celeborn.common.protocol.PartitionLocation;
-import org.apache.celeborn.common.protocol.StorageInfo;
-import org.apache.celeborn.common.protocol.TransportModuleConstants;
+import org.apache.celeborn.common.protocol.*;
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.ExceptionMaker;
import org.apache.celeborn.common.util.Utils;
@@ -52,7 +51,8 @@ public abstract class CelebornInputStream extends InputStream
{
CelebornConf conf,
TransportClientFactory clientFactory,
String shuffleKey,
- PartitionLocation[] locations,
+ ArrayList<PartitionLocation> locations,
+ ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
int attemptNumber,
int startMapIndex,
@@ -65,7 +65,7 @@ public abstract class CelebornInputStream extends InputStream
{
ExceptionMaker exceptionMaker,
MetricsCallback metricsCallback)
throws IOException {
- if (locations == null || locations.length == 0) {
+ if (locations == null || locations.size() == 0) {
return emptyInputStream;
} else {
return new CelebornInputStreamImpl(
@@ -73,6 +73,7 @@ public abstract class CelebornInputStream extends InputStream
{
clientFactory,
shuffleKey,
locations,
+ streamHandlers,
attempts,
attemptNumber,
startMapIndex,
@@ -124,7 +125,8 @@ public abstract class CelebornInputStream extends
InputStream {
private final CelebornConf conf;
private final TransportClientFactory clientFactory;
private final String shuffleKey;
- private PartitionLocation[] locations;
+ private ArrayList<PartitionLocation> locations;
+ private ArrayList<PbStreamHandler> streamHandlers;
private int[] attempts;
private final int attemptNumber;
private final int startMapIndex;
@@ -174,7 +176,8 @@ public abstract class CelebornInputStream extends
InputStream {
CelebornConf conf,
TransportClientFactory clientFactory,
String shuffleKey,
- PartitionLocation[] locations,
+ ArrayList<PartitionLocation> locations,
+ ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
int attemptNumber,
int startMapIndex,
@@ -190,7 +193,10 @@ public abstract class CelebornInputStream extends
InputStream {
this.conf = conf;
this.clientFactory = clientFactory;
this.shuffleKey = shuffleKey;
- this.locations = (PartitionLocation[]) Utils.randomizeInPlace(locations,
RAND);
+ this.locations = locations;
+ if (streamHandlers != null && streamHandlers.size() == locations.size())
{
+ this.streamHandlers = streamHandlers;
+ }
this.attempts = attempts;
this.attemptNumber = attemptNumber;
this.startMapIndex = startMapIndex;
@@ -242,24 +248,25 @@ public abstract class CelebornInputStream extends
InputStream {
return true;
}
- private PartitionLocation nextReadableLocation() {
- int locationCount = locations.length;
+ private Tuple2<PartitionLocation, PbStreamHandler> nextReadableLocation() {
+ int locationCount = locations.size();
if (fileIndex >= locationCount) {
return null;
}
- PartitionLocation currentLocation = locations[fileIndex];
+ PartitionLocation currentLocation = locations.get(fileIndex);
while (skipLocation(startMapIndex, endMapIndex, currentLocation)) {
skipCount.increment();
fileIndex++;
if (fileIndex == locationCount) {
return null;
}
- currentLocation = locations[fileIndex];
+ currentLocation = locations.get(fileIndex);
}
fetchChunkRetryCnt = 0;
- return currentLocation;
+ return new Tuple2(
+ currentLocation, streamHandlers == null ? null :
streamHandlers.get(fileIndex));
}
private void moveToNextReader(boolean fetchChunk) throws IOException {
@@ -267,11 +274,11 @@ public abstract class CelebornInputStream extends
InputStream {
currentReader.close();
currentReader = null;
}
- PartitionLocation currentLocation = nextReadableLocation();
+ Tuple2<PartitionLocation, PbStreamHandler> currentLocation =
nextReadableLocation();
if (currentLocation == null) {
return;
}
- currentReader = createReaderWithRetry(currentLocation);
+ currentReader = createReaderWithRetry(currentLocation._1,
currentLocation._2);
fileIndex++;
while (!currentReader.hasNext()) {
currentReader.close();
@@ -280,7 +287,7 @@ public abstract class CelebornInputStream extends
InputStream {
if (currentLocation == null) {
return;
}
- currentReader = createReaderWithRetry(currentLocation);
+ currentReader = createReaderWithRetry(currentLocation._1,
currentLocation._2);
fileIndex++;
}
if (fetchChunk) {
@@ -332,14 +339,15 @@ public abstract class CelebornInputStream extends
InputStream {
return connectException || rpcTimeout || fetchChunkTimeout;
}
- private PartitionReader createReaderWithRetry(PartitionLocation location)
throws IOException {
+ private PartitionReader createReaderWithRetry(
+ PartitionLocation location, PbStreamHandler pbStreamHandler) throws
IOException {
Exception lastException = null;
while (fetchChunkRetryCnt < fetchChunkMaxRetry) {
try {
if (isExcluded(location)) {
throw new CelebornIOException("Fetch data from excluded worker! "
+ location);
}
- return createReader(location, fetchChunkRetryCnt,
fetchChunkMaxRetry);
+ return createReader(location, pbStreamHandler, fetchChunkRetryCnt,
fetchChunkMaxRetry);
} catch (Exception e) {
lastException = e;
excludeFailedLocation(location, e);
@@ -404,7 +412,7 @@ public abstract class CelebornInputStream extends
InputStream {
if (fetchChunkRetryCnt % 2 == 0) {
Uninterruptibles.sleepUninterruptibly(retryWaitMs,
TimeUnit.MILLISECONDS);
}
- currentReader =
createReaderWithRetry(currentReader.getLocation().getPeer());
+ currentReader =
createReaderWithRetry(currentReader.getLocation().getPeer(), null);
} else {
logger.warn(
"Fetch chunk failed {}/{} times for location {}",
@@ -413,7 +421,7 @@ public abstract class CelebornInputStream extends
InputStream {
currentReader.getLocation(),
e);
Uninterruptibles.sleepUninterruptibly(retryWaitMs,
TimeUnit.MILLISECONDS);
- currentReader =
createReaderWithRetry(currentReader.getLocation());
+ currentReader =
createReaderWithRetry(currentReader.getLocation(), null);
}
}
}
@@ -422,11 +430,14 @@ public abstract class CelebornInputStream extends
InputStream {
}
private PartitionReader createReader(
- PartitionLocation location, int fetchChunkRetryCnt, int
fetchChunkMaxRetry)
+ PartitionLocation location,
+ PbStreamHandler pbStreamHandler,
+ int fetchChunkRetryCnt,
+ int fetchChunkMaxRetry)
throws IOException, InterruptedException {
if (!location.hasPeer()) {
logger.debug("Partition {} has only one partition replica.", location);
- } else if (attemptNumber % 2 == 1) {
+ } else if (pbStreamHandler == null && attemptNumber % 2 == 1) {
location = location.getPeer();
logger.debug("Read peer {} for attempt {}.", location, attemptNumber);
}
@@ -446,6 +457,7 @@ public abstract class CelebornInputStream extends
InputStream {
conf,
shuffleKey,
location,
+ pbStreamHandler,
clientFactory,
startMapIndex,
endMapIndex,
@@ -513,7 +525,7 @@ public abstract class CelebornInputStream extends
InputStream {
@Override
public synchronized void close() {
if (!closed) {
- int locationsCount = locations.length;
+ int locationsCount = locations.size();
logger.debug(
"AppShuffleId {}, shuffleId {}, partitionId {}, total location
count {}, read {}, skip {}",
appShuffleId,
@@ -556,7 +568,7 @@ public abstract class CelebornInputStream extends
InputStream {
if (currentReader.hasNext()) {
currentChunk = getNextChunk();
return true;
- } else if (fileIndex < locations.length) {
+ } else if (fileIndex < locations.size()) {
moveToNextReader(true);
return currentReader != null;
}
@@ -668,7 +680,7 @@ public abstract class CelebornInputStream extends
InputStream {
@Override
public int totalPartitionsToRead() {
- return locations.length;
+ return locations.size();
}
@Override
diff --git
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
index 0d474ee72..3158aa12f 100644
---
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
+++
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
@@ -74,6 +74,7 @@ public class WorkerPartitionReader implements PartitionReader
{
CelebornConf conf,
String shuffleKey,
PartitionLocation location,
+ PbStreamHandler pbStreamHandler,
TransportClientFactory clientFactory,
int startMapIndex,
int endMapIndex,
@@ -116,18 +117,22 @@ public class WorkerPartitionReader implements
PartitionReader {
throw ie;
}
- TransportMessage openStreamMsg =
- new TransportMessage(
- MessageType.OPEN_STREAM,
- PbOpenStream.newBuilder()
- .setShuffleKey(shuffleKey)
- .setFileName(location.getFileName())
- .setStartIndex(startMapIndex)
- .setEndIndex(endMapIndex)
- .build()
- .toByteArray());
- ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(),
fetchTimeoutMs);
- streamHandler =
TransportMessage.fromByteBuffer(response).getParsedPayload();
+ if (pbStreamHandler == null) {
+ TransportMessage openStreamMsg =
+ new TransportMessage(
+ MessageType.OPEN_STREAM,
+ PbOpenStream.newBuilder()
+ .setShuffleKey(shuffleKey)
+ .setFileName(location.getFileName())
+ .setStartIndex(startMapIndex)
+ .setEndIndex(endMapIndex)
+ .build()
+ .toByteArray());
+ ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(),
fetchTimeoutMs);
+ this.streamHandler =
TransportMessage.fromByteBuffer(response).getParsedPayload();
+ } else {
+ this.streamHandler = pbStreamHandler;
+ }
this.location = location;
this.clientFactory = clientFactory;
diff --git
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 47642019a..cae634fa7 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -36,7 +36,10 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornIOException;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.ExceptionMaker;
import org.apache.celeborn.common.util.JavaUtils;
@@ -115,6 +118,12 @@ public class DummyShuffleClient extends ShuffleClient {
@Override
public void cleanup(int shuffleId, int mapId, int attemptId) {}
+ @Override
+ public ShuffleClientImpl.ReduceFileGroups updateFileGroup(int shuffleId, int
partitionId)
+ throws CelebornIOException {
+ return null;
+ }
+
@Override
public CelebornInputStream readPartition(
int shuffleId,
@@ -124,6 +133,9 @@ public class DummyShuffleClient extends ShuffleClient {
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
+ ArrayList<PartitionLocation> locations,
+ ArrayList<PbStreamHandler> streamHandlers,
+ int[] mapAttempts,
MetricsCallback metricsCallback)
throws IOException {
return null;
@@ -170,6 +182,11 @@ public class DummyShuffleClient extends ShuffleClient {
return true;
}
+ @Override
+ public TransportClientFactory getDataClientFactory() {
+ return null;
+ }
+
public void initReducePartitionMap(int shuffleId, int numPartitions, int
workerNum) {
ConcurrentHashMap<Integer, PartitionLocation> map =
JavaUtils.newConcurrentHashMap();
String host = "host";
diff --git
a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
index 88247d9b6..d5631d5b0 100644
---
a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
+++
b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
@@ -152,11 +152,33 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
}
// reduce normal empty CelebornInputStream
- var stream = shuffleClient.readPartition(shuffleId, 1, 1, 0,
Integer.MAX_VALUE, metricsCallback)
+ var stream = shuffleClient.readPartition(
+ shuffleId,
+ shuffleId,
+ 1,
+ 1,
+ 0,
+ Integer.MAX_VALUE,
+ null,
+ null,
+ null,
+ null,
+ metricsCallback)
Assert.assertEquals(stream.read(), -1)
// reduce normal null partition for CelebornInputStream
- stream = shuffleClient.readPartition(shuffleId, 3, 1, 0,
Integer.MAX_VALUE, metricsCallback)
+ stream = shuffleClient.readPartition(
+ shuffleId,
+ shuffleId,
+ 3,
+ 1,
+ 0,
+ Integer.MAX_VALUE,
+ null,
+ null,
+ null,
+ null,
+ metricsCallback)
Assert.assertEquals(stream.read(), -1)
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 392edce85..1d684b217 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -28,29 +28,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.exception.CelebornIOException;
-import org.apache.celeborn.common.protocol.MessageType;
-import org.apache.celeborn.common.protocol.PbApplicationMeta;
-import org.apache.celeborn.common.protocol.PbApplicationMetaRequest;
-import org.apache.celeborn.common.protocol.PbAuthenticationInitiationRequest;
-import org.apache.celeborn.common.protocol.PbAuthenticationInitiationResponse;
-import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
-import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
-import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
-import org.apache.celeborn.common.protocol.PbGetShuffleId;
-import org.apache.celeborn.common.protocol.PbGetShuffleIdResponse;
-import org.apache.celeborn.common.protocol.PbOpenStream;
-import org.apache.celeborn.common.protocol.PbPushDataHandShake;
-import org.apache.celeborn.common.protocol.PbReadAddCredit;
-import org.apache.celeborn.common.protocol.PbRegionFinish;
-import org.apache.celeborn.common.protocol.PbRegionStart;
-import org.apache.celeborn.common.protocol.PbRegisterApplicationRequest;
-import org.apache.celeborn.common.protocol.PbRegisterApplicationResponse;
-import org.apache.celeborn.common.protocol.PbReportShuffleFetchFailure;
-import org.apache.celeborn.common.protocol.PbReportShuffleFetchFailureResponse;
-import org.apache.celeborn.common.protocol.PbSaslRequest;
-import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
-import org.apache.celeborn.common.protocol.PbStreamHandler;
-import org.apache.celeborn.common.protocol.PbTransportableError;
+import org.apache.celeborn.common.protocol.*;
public class TransportMessage implements Serializable {
private static final long serialVersionUID = -3259000920699629773L;
@@ -123,6 +101,10 @@ public class TransportMessage implements Serializable {
return (T) PbApplicationMeta.parseFrom(payload);
case APPLICATION_META_REQUEST_VALUE:
return (T) PbApplicationMetaRequest.parseFrom(payload);
+ case BATCH_OPEN_STREAM_VALUE:
+ return (T) PbOpenStreamList.parseFrom(payload);
+ case BATCH_OPEN_STREAM_RESPONSE_VALUE:
+ return (T) PbOpenStreamListResponse.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index 3d5d6a790..0ebfad65a 100644
---
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -81,7 +81,8 @@ public enum StatusCode {
REVIVE_INITIALIZED(47),
DESTROY_SLOTS_MOCK_FAILURE(48),
COMMIT_FILES_MOCK_FAILURE(49),
- PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA(50);
+ PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA(50),
+ OPEN_STREAM_FAILED(51);
private final byte value;
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index 2d24431e2..7f190a3bb 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -100,6 +100,8 @@ enum MessageType {
WORKER_EVENT_RESPONSE = 77;
APPLICATION_META = 78;
APPLICATION_META_REQUEST = 79;
+ BATCH_OPEN_STREAM = 80;
+ BATCH_OPEN_STREAM_RESPONSE = 81;
}
enum StreamType {
@@ -635,6 +637,25 @@ message PbStreamHandler {
string fullPath = 4;
}
+message PbOpenStreamList {
+ string shuffleKey = 1;
+ repeated string fileName = 2;
+ repeated int32 startIndex = 3;
+ repeated int32 endIndex = 4;
+ repeated int32 initialCredit = 5;
+ repeated bool readLocalShuffle = 6;
+}
+
+message PbStreamHandlerOpt {
+ int32 status = 1;
+ PbStreamHandler streamHandler = 2;
+ string errorMsg = 3;
+}
+
+message PbOpenStreamListResponse {
+ repeated PbStreamHandlerOpt streamHandlerOpt = 2;
+}
+
message PbPushDataHandShake {
PbPartitionLocation.Mode mode = 1;
string shuffleKey = 2;
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index ec1ac8b03..114f50f2b 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -20,9 +20,12 @@ package org.apache.celeborn.service.deploy.worker
import java.io.{FileNotFoundException, IOException}
import java.nio.charset.StandardCharsets
import java.util
+import java.util.concurrent.{Future => JFuture}
import java.util.concurrent.atomic.AtomicBoolean
import java.util.function.Consumer
+import scala.collection.JavaConverters._
+
import com.google.common.base.Throwables
import com.google.protobuf.GeneratedMessageV3
import io.netty.util.concurrent.{Future, GenericFutureListener}
@@ -37,8 +40,9 @@ import
org.apache.celeborn.common.network.client.{RpcResponseCallback, Transport
import org.apache.celeborn.common.network.protocol._
import org.apache.celeborn.common.network.server.BaseMessageHandler
import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
-import org.apache.celeborn.common.protocol.{MessageType, PbBufferStreamEnd,
PbChunkFetchRequest, PbOpenStream, PbReadAddCredit, PbStreamHandler, StreamType}
-import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
+import org.apache.celeborn.common.protocol.{MessageType, PbBufferStreamEnd,
PbChunkFetchRequest, PbOpenStream, PbOpenStreamList, PbOpenStreamListResponse,
PbReadAddCredit, PbStreamHandler, PbStreamHandlerOpt, StreamType}
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.{ExceptionUtils, ThreadUtils, Utils}
import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager,
CreditStreamManager, PartitionFilesSorter, StorageManager}
class FetchHandler(
@@ -136,6 +140,33 @@ class FetchHandler(
isLegacy = false,
openStream.getReadLocalShuffle,
callback)
+ case openStreamList: PbOpenStreamList =>
+ val shuffleKey = openStreamList.getShuffleKey()
+ val files = openStreamList.getFileNameList
+ val startIndices = openStreamList.getStartIndexList
+ val endIndices = openStreamList.getEndIndexList
+ val readLocalFlags = openStreamList.getReadLocalShuffleList
+ val pbOpenStreamListResponse = PbOpenStreamListResponse.newBuilder()
+
+ 0 until files.size() foreach { idx =>
+ val pbStreamHandlerOpt = handleReduceOpenStreamInternal(
+ client,
+ shuffleKey,
+ files.get(idx),
+ startIndices.get(idx),
+ endIndices.get(idx),
+ readLocalFlags.get(idx))
+ if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) {
+ workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT)
+ }
+ pbOpenStreamListResponse.addStreamHandlerOpt(pbStreamHandlerOpt)
+ }
+
+ client.getChannel.writeAndFlush(new RpcResponse(
+ rpcRequest.requestId,
+ new NioManagedBuffer(new TransportMessage(
+ MessageType.BATCH_OPEN_STREAM_RESPONSE,
+ pbOpenStreamListResponse.build().toByteArray).toByteBuffer)))
case bufferStreamEnd: PbBufferStreamEnd =>
handleEndStreamFromClient(
client,
@@ -197,6 +228,86 @@ class FetchHandler(
}
+ private def handleReduceOpenStreamInternal(
+ client: TransportClient,
+ shuffleKey: String,
+ fileName: String,
+ startIndex: Int,
+ endIndex: Int,
+ readLocalShuffle: Boolean = false): PbStreamHandlerOpt = {
+ try {
+ logDebug(s"Received open stream request $shuffleKey $fileName
$startIndex " +
+ s"$endIndex get file name $fileName from client channel " +
+ s"${NettyUtils.getRemoteAddress(client.getChannel)}")
+
+ var fileInfo = getRawDiskFileInfo(shuffleKey, fileName)
+ val streamId = chunkStreamManager.nextStreamId()
+ // we must get sorted fileInfo for the following cases.
+ // 1. when the current request is a non-range openStream, but the
original unsorted file
+ // has been deleted by another range's openStream request.
+ // 2. when the current request is a range openStream request.
+ if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue &&
!fileInfo.addStream(
+ streamId))) {
+ fileInfo = partitionsSorter.getSortedFileInfo(
+ shuffleKey,
+ fileName,
+ fileInfo,
+ startIndex,
+ endIndex)
+ }
+ val meta = fileInfo.getFileMeta.asInstanceOf[ReduceFileMeta]
+ val streamHandler =
+ if (readLocalShuffle) {
+ chunkStreamManager.registerStream(
+ streamId,
+ shuffleKey,
+ fileName)
+ makeStreamHandler(
+ streamId,
+ meta.getNumChunks,
+ meta.getChunkOffsets,
+ fileInfo.getFilePath)
+ } else if (fileInfo.isHdfs) {
+ chunkStreamManager.registerStream(
+ streamId,
+ shuffleKey,
+ fileName)
+ makeStreamHandler(streamId, numChunks = 0)
+ } else {
+ chunkStreamManager.registerStream(
+ streamId,
+ shuffleKey,
+ new FileManagedBuffers(fileInfo, transportConf),
+ fileName,
+ storageManager.getFetchTimeMetric(fileInfo.getFile))
+ if (meta.getNumChunks == 0)
+ logDebug(s"StreamId $streamId, fileName $fileName, mapRange " +
+ s"[$startIndex-$endIndex] is empty. Received from client channel
" +
+ s"${NettyUtils.getRemoteAddress(client.getChannel)}")
+ else logDebug(
+ s"StreamId $streamId, fileName $fileName, numChunks
${meta.getNumChunks}, " +
+ s"mapRange [$startIndex-$endIndex]. Received from client channel
" +
+ s"${NettyUtils.getRemoteAddress(client.getChannel)}")
+ makeStreamHandler(
+ streamId,
+ meta.getNumChunks)
+ }
+ workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT)
+ PbStreamHandlerOpt.newBuilder().setStreamHandler(streamHandler)
+ .setStatus(StatusCode.SUCCESS.getValue)
+ .build()
+ } catch {
+ case e: IOException =>
+ val msg =
+ s"Read file: $fileName with shuffleKey: $shuffleKey error from
${NettyUtils.getRemoteAddress(
+ client.getChannel)}, Exception: ${e.getMessage}"
+
PbStreamHandlerOpt.newBuilder().setStatus(StatusCode.OPEN_STREAM_FAILED.getValue)
+ .setErrorMsg(msg).build()
+ } finally {
+ workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
+ }
+ }
+
private def handleOpenStreamInternal(
client: TransportClient,
shuffleKey: String,
@@ -211,74 +322,31 @@ class FetchHandler(
workerSource.recordAppActiveConnection(client, shuffleKey)
workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
try {
- var fileInfo = getRawDiskFileInfo(shuffleKey, fileName)
+ val fileInfo = getRawDiskFileInfo(shuffleKey, fileName)
fileInfo.getFileMeta match {
case _: ReduceFileMeta =>
- logDebug(s"Received open stream request $shuffleKey $fileName
$startIndex " +
- s"$endIndex get file name $fileName from client channel " +
- s"${NettyUtils.getRemoteAddress(client.getChannel)}")
-
- val streamId = chunkStreamManager.nextStreamId()
- // we must get sorted fileInfo for the following cases.
- // 1. when the current request is a non-range openStream, but the
original unsorted file
- // has been deleted by another range's openStream request.
- // 2. when the current request is a range openStream request.
- if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue &&
!fileInfo.addStream(
- streamId))) {
- fileInfo = partitionsSorter.getSortedFileInfo(
- shuffleKey,
- fileName,
- fileInfo,
- startIndex,
- endIndex)
- }
- val meta = fileInfo.getFileMeta.asInstanceOf[ReduceFileMeta]
- if (readLocalShuffle) {
- chunkStreamManager.registerStream(
- streamId,
- shuffleKey,
- fileName)
- replyStreamHandler(
+ val pbStreamHandlerOpt =
+ handleReduceOpenStreamInternal(
client,
- rpcRequestId,
- streamId,
- meta.getNumChunks,
- isLegacy,
- meta.getChunkOffsets,
- fileInfo.getFilePath)
- } else if (fileInfo.isHdfs) {
- chunkStreamManager.registerStream(
- streamId,
- shuffleKey,
- fileName)
- replyStreamHandler(client, rpcRequestId, streamId, numChunks = 0,
isLegacy)
- } else {
- chunkStreamManager.registerStream(
- streamId,
shuffleKey,
- new FileManagedBuffers(fileInfo, transportConf),
fileName,
- storageManager.getFetchTimeMetric(fileInfo.getFile))
- if (meta.getNumChunks == 0)
- logDebug(s"StreamId $streamId, fileName $fileName, mapRange " +
- s"[$startIndex-$endIndex] is empty. Received from client
channel " +
- s"${NettyUtils.getRemoteAddress(client.getChannel)}")
- else logDebug(
- s"StreamId $streamId, fileName $fileName, numChunks
${meta.getNumChunks}, " +
- s"mapRange [$startIndex-$endIndex]. Received from client
channel " +
- s"${NettyUtils.getRemoteAddress(client.getChannel)}")
- replyStreamHandler(
- client,
- rpcRequestId,
- streamId,
- meta.getNumChunks,
- isLegacy)
+ startIndex,
+ endIndex,
+ readLocalShuffle)
+
+ if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) {
+ throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg)
}
+ replyStreamHandler(client, rpcRequestId,
pbStreamHandlerOpt.getStreamHandler, isLegacy)
case _: MapFileMeta =>
val creditStreamHandler =
new Consumer[java.lang.Long] {
override def accept(streamId: java.lang.Long): Unit = {
- replyStreamHandler(client, rpcRequestId, streamId, 0, isLegacy)
+ val pbStreamHandler = PbStreamHandler.newBuilder
+ .setStreamId(streamId)
+ .setNumChunks(0)
+ .build()
+ replyStreamHandler(client, rpcRequestId, pbStreamHandler,
isLegacy)
}
}
@@ -301,28 +369,34 @@ class FetchHandler(
}
}
- private def replyStreamHandler(
- client: TransportClient,
- requestId: Long,
+ private def makeStreamHandler(
streamId: Long,
numChunks: Int,
- isLegacy: Boolean,
offsets: util.List[java.lang.Long] = null,
- filepath: String = ""): Unit = {
+ filepath: String = ""): PbStreamHandler = {
+ val pbStreamHandlerBuilder =
PbStreamHandler.newBuilder.setStreamId(streamId).setNumChunks(
+ numChunks)
+ if (offsets != null) {
+ pbStreamHandlerBuilder.addAllChunkOffsets(offsets)
+ }
+ if (filepath.nonEmpty) {
+ pbStreamHandlerBuilder.setFullPath(filepath)
+ }
+ pbStreamHandlerBuilder.build()
+ }
+
+ private def replyStreamHandler(
+ client: TransportClient,
+ requestId: Long,
+ pbStreamHandler: PbStreamHandler,
+ isLegacy: Boolean): Unit = {
if (isLegacy) {
client.getChannel.writeAndFlush(new RpcResponse(
requestId,
- new NioManagedBuffer(new StreamHandle(streamId,
numChunks).toByteBuffer)))
+ new NioManagedBuffer(new StreamHandle(
+ pbStreamHandler.getStreamId,
+ pbStreamHandler.getNumChunks).toByteBuffer)))
} else {
- val pbStreamHandlerBuilder =
PbStreamHandler.newBuilder.setStreamId(streamId).setNumChunks(
- numChunks)
- if (offsets != null) {
- pbStreamHandlerBuilder.addAllChunkOffsets(offsets)
- }
- if (filepath.nonEmpty) {
- pbStreamHandlerBuilder.setFullPath(filepath)
- }
- val pbStreamHandler = pbStreamHandlerBuilder.build()
client.getChannel.writeAndFlush(new RpcResponse(
requestId,
new NioManagedBuffer(new TransportMessage(
diff --git
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
index 3009a6445..d62a78515 100644
---
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
+++
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
@@ -104,7 +104,18 @@ trait ReadWriteTestBase extends AnyFunSuite
override def incBytesRead(bytesWritten: Long): Unit = {}
override def incReadTime(time: Long): Unit = {}
}
- val inputStream = shuffleClient.readPartition(1, 0, 0, 0,
Integer.MAX_VALUE, metricsCallback)
+ val inputStream = shuffleClient.readPartition(
+ 1,
+ 1,
+ 0,
+ 0,
+ 0,
+ Integer.MAX_VALUE,
+ null,
+ null,
+ null,
+ null,
+ metricsCallback)
val outputStream = new ByteArrayOutputStream()
var b = inputStream.read()