This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new ea39a9372 [CELEBORN-760] Convert OpenStream and StreamHandler to Pb
ea39a9372 is described below
commit ea39a9372aec1995c7bc38dcadc38a41127bc50b
Author: mingji <[email protected]>
AuthorDate: Sat Aug 5 13:58:08 2023 +0800
[CELEBORN-760] Convert OpenStream and StreamHandler to Pb
### What changes were proposed in this pull request?
Convert OpenStream and StreamHandle to protobuf-based transport messages
(PbOpenStream and PbStreamHandler) to improve Celeborn's compatibility.
### Why are the changes needed?
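1. Make the open-stream RPC messages easier to evolve.
2. Stay compatible with 0.2 clients: the worker still accepts legacy OpenStream/OpenStreamWithCredit RPCs and answers them with a legacy StreamHandle.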
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests and testing on a cluster.
Closes #1750 from FMX/CELEBORN-760.
Lead-authored-by: mingji <[email protected]>
Co-authored-by: Ethan Feng <[email protected]>
Co-authored-by: Keyong Zhou <[email protected]>
Co-authored-by: Keyong Zhou <[email protected]>
Co-authored-by: Keyong Zhou <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
---
.../celeborn/client/read/DfsPartitionReader.java | 21 ++-
.../client/read/WorkerPartitionReader.java | 36 ++--
.../celeborn/common/network/protocol/Message.java | 1 +
.../common/network/protocol/OpenStream.java | 6 +-
.../network/protocol/OpenStreamWithCredit.java | 1 +
.../common/network/protocol/StreamHandle.java | 3 +-
.../common/network/protocol/TransportMessage.java | 47 ++++++
common/src/main/proto/TransportMessages.proto | 18 ++
.../celeborn/tests/flink/HeartbeatTest.scala | 1 -
.../service/deploy/worker/FetchHandler.scala | 183 ++++++++++++++-------
.../deploy/worker/storage/FileWriterSuiteJ.java | 35 ++--
11 files changed, 253 insertions(+), 99 deletions(-)
diff --git
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
index 316f11ef1..637aada7f 100644
---
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
+++
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
@@ -37,9 +37,10 @@ import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.Message;
-import org.apache.celeborn.common.network.protocol.OpenStream;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.util.ShuffleBlockInfoUtils;
import org.apache.celeborn.common.util.Utils;
@@ -77,10 +78,18 @@ public class DfsPartitionReader implements PartitionReader {
try {
TransportClient client =
clientFactory.createClient(location.getHost(),
location.getFetchPort());
- OpenStream openBlocks =
- new OpenStream(shuffleKey, location.getFileName(), startMapIndex,
endMapIndex);
- ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(),
fetchTimeoutMs);
- Message.decode(response);
+ TransportMessage openStream =
+ new TransportMessage(
+ MessageType.OPEN_STREAM,
+ PbOpenStream.newBuilder()
+ .setShuffleKey(shuffleKey)
+ .setFileName(location.getFileName())
+ .setStartIndex(startMapIndex)
+ .setEndIndex(endMapIndex)
+ .build()
+ .toByteArray());
+ ByteBuffer response = client.sendRpcSync(openStream.toByteBuffer(),
fetchTimeoutMs);
+ TransportMessage.fromByteBuffer(response).getParsedPayload();
// Parse this message to ensure sort is done.
} catch (IOException | InterruptedException e) {
throw new IOException(
diff --git
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
index e323fe6e1..c02c7dd4f 100644
---
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
+++
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
@@ -35,17 +35,18 @@ import
org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.Message;
-import org.apache.celeborn.common.network.protocol.OpenStream;
-import org.apache.celeborn.common.network.protocol.StreamHandle;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.util.ExceptionUtils;
public class WorkerPartitionReader implements PartitionReader {
private final Logger logger =
LoggerFactory.getLogger(WorkerPartitionReader.class);
private PartitionLocation location;
private final TransportClientFactory clientFactory;
- private StreamHandle streamHandle;
+ private PbStreamHandler streamHandle;
private int returnedChunks;
private int chunkIndex;
@@ -105,10 +106,19 @@ public class WorkerPartitionReader implements
PartitionReader {
logger.error("PartitionReader thread interrupted while creating
client.");
throw ie;
}
- OpenStream openBlocks =
- new OpenStream(shuffleKey, location.getFileName(), startMapIndex,
endMapIndex);
- ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(),
fetchTimeoutMs);
- streamHandle = (StreamHandle) Message.decode(response);
+
+ TransportMessage openStreamMsg =
+ new TransportMessage(
+ MessageType.OPEN_STREAM,
+ PbOpenStream.newBuilder()
+ .setShuffleKey(shuffleKey)
+ .setFileName(location.getFileName())
+ .setStartIndex(startMapIndex)
+ .setEndIndex(endMapIndex)
+ .build()
+ .toByteArray());
+ ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(),
fetchTimeoutMs);
+ streamHandle =
TransportMessage.fromByteBuffer(response).getParsedPayload();
this.location = location;
this.clientFactory = clientFactory;
@@ -118,12 +128,12 @@ public class WorkerPartitionReader implements
PartitionReader {
}
public boolean hasNext() {
- return returnedChunks < streamHandle.numChunks;
+ return returnedChunks < streamHandle.getNumChunks();
}
public ByteBuf next() throws IOException, InterruptedException {
checkException();
- if (chunkIndex < streamHandle.numChunks) {
+ if (chunkIndex < streamHandle.getNumChunks()) {
fetchChunks();
}
ByteBuf chunk = null;
@@ -159,7 +169,7 @@ public class WorkerPartitionReader implements
PartitionReader {
final int inFlight = chunkIndex - returnedChunks;
if (inFlight < fetchMaxReqsInFlight) {
final int toFetch =
- Math.min(fetchMaxReqsInFlight - inFlight + 1, streamHandle.numChunks
- chunkIndex);
+ Math.min(fetchMaxReqsInFlight - inFlight + 1,
streamHandle.getNumChunks() - chunkIndex);
for (int i = 0; i < toFetch; i++) {
if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 &&
chunkIndex == 3) {
callback.onFailure(chunkIndex, new CelebornIOException("Test fetch
chunk failure"));
@@ -167,12 +177,12 @@ public class WorkerPartitionReader implements
PartitionReader {
try {
TransportClient client =
clientFactory.createClient(location.getHost(),
location.getFetchPort());
- client.fetchChunk(streamHandle.streamId, chunkIndex,
fetchTimeoutMs, callback);
+ client.fetchChunk(streamHandle.getStreamId(), chunkIndex,
fetchTimeoutMs, callback);
chunkIndex++;
} catch (IOException e) {
logger.error(
"fetchChunk for streamId: {}, chunkIndex: {} failed.",
- streamHandle.streamId,
+ streamHandle.getStreamId(),
chunkIndex,
e);
ExceptionUtils.wrapAndThrowIOException(e);
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
index 989c40e50..b2aaf7354 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
@@ -230,6 +230,7 @@ public abstract class Message implements Encodable {
case TRANSPORTABLE_ERROR:
return TransportableError.decode(in);
+
case BUFFER_STREAM_END:
return BufferStreamEnd.decode(in);
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStream.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStream.java
index ff1ae1565..da5b3dad3 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStream.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStream.java
@@ -23,7 +23,11 @@ import java.util.Arrays;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
-/** Request to read a set of blocks. Returns {@link StreamHandle}. */
+/**
+ * Request to read a set of blocks. Returns {@link StreamHandle}. Use
PbOpenStream instead of this
+ * one.
+ */
+@Deprecated
public final class OpenStream extends RequestMessage {
public byte[] shuffleKey;
public byte[] fileName;
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java
index d59d86dfc..ea3f12a6b 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java
@@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets;
import io.netty.buffer.ByteBuf;
/** Buffer stream used in Map partition scenario. */
+@Deprecated
public final class OpenStreamWithCredit extends RequestMessage {
public final byte[] shuffleKey;
public final byte[] fileName;
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamHandle.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamHandle.java
index 4d50bec66..c2428728e 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamHandle.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamHandle.java
@@ -22,8 +22,9 @@ import io.netty.buffer.ByteBuf;
/**
* Identifier for a fixed number of chunks to read from a stream created by an
"open blocks"
- * message.
+ * message. Use PbStreamHandler instead of this.
*/
+@Deprecated
public final class StreamHandle extends RequestMessage {
public final long streamId;
public final int numChunks;
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index a139fc08e..d72bfec1b 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -17,12 +17,25 @@
package org.apache.celeborn.common.network.protocol;
+import static
org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
+import static
org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;
+
import java.io.Serializable;
+import java.nio.ByteBuffer;
+
+import com.google.protobuf.GeneratedMessageV3;
+import com.google.protobuf.InvalidProtocolBufferException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
public class TransportMessage implements Serializable {
private static final long serialVersionUID = -3259000920699629773L;
+ private static Logger logger =
LoggerFactory.getLogger(TransportMessage.class);
@Deprecated private final MessageType type;
private final int messageTypeValue;
private final byte[] payload;
@@ -44,4 +57,38 @@ public class TransportMessage implements Serializable {
public byte[] getPayload() {
return payload;
}
+
+ public <T extends GeneratedMessageV3> T getParsedPayload() throws
InvalidProtocolBufferException {
+ switch (messageTypeValue) {
+ case OPEN_STREAM_VALUE:
+ return (T) PbOpenStream.parseFrom(payload);
+ case STREAM_HANDLER_VALUE:
+ return (T) PbStreamHandler.parseFrom(payload);
+ default:
+ logger.error("Unexpected type {}", type);
+ }
+ return null;
+ }
+
+ public ByteBuffer toByteBuffer() {
+ int totalBufferSize = payload.length + 4 + 4;
+ ByteBuffer buffer = ByteBuffer.allocate(totalBufferSize);
+ buffer.putInt(messageTypeValue);
+ buffer.putInt(payload.length);
+ buffer.put(payload);
+ buffer.flip();
+ return buffer;
+ }
+
+ public static TransportMessage fromByteBuffer(ByteBuffer buffer) throws
CelebornIOException {
+ int messageTypeValue = buffer.getInt();
+ if (MessageType.forNumber(messageTypeValue) == null) {
+ throw new CelebornIOException("Decode failed, fallback to legacy
messages.");
+ }
+ int payloadLen = buffer.getInt();
+ byte[] payload = new byte[payloadLen];
+ buffer.get(payload);
+ MessageType msgType = MessageType.forNumber(messageTypeValue);
+ return new TransportMessage(msgType, payload);
+ }
}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index 0454ceff1..172dc1a91 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -69,6 +69,8 @@ enum MessageType {
REGISTER_MAP_PARTITION_TASK = 48;
HEARTBEAT_FROM_APPLICATION_RESPONSE = 49;
CHECK_FOR_HDFS_EXPIRED_DIRS_TIMEOUT = 50;
+ OPEN_STREAM = 51;
+ STREAM_HANDLER = 52;
}
message PbStorageInfo {
@@ -446,3 +448,19 @@ message PbSnapshotMetaInfo {
map<string, int64> lostWorkers = 12;
repeated PbWorkerInfo shutdownWorkers = 13;
}
+
+message PbOpenStream {
+ string shuffleKey = 1;
+ string fileName = 2;
+ int32 startIndex = 3;
+ int32 endIndex = 4;
+ int32 initialCredit = 5;
+ bool localRead = 6;
+}
+
+message PbStreamHandler {
+ int64 streamId = 1 ;
+ int32 numChunks = 2;
+ repeated int64 chunkOffsets = 3 ;
+ string fullPath = 4;
+}
\ No newline at end of file
diff --git
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
index 625debb4c..a373c6fd8 100644
---
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
+++
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
@@ -24,7 +24,6 @@ import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl
import org.apache.celeborn.service.deploy.{HeartbeatFeature,
MiniClusterFeature}
-import org.apache.celeborn.service.deploy.worker.memory.MemoryManager;
class HeartbeatTest extends AnyFunSuite with Logging with MiniClusterFeature
with HeartbeatFeature
with BeforeAndAfterAll with BeforeAndAfterEach {
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 5426bbb77..45e700695 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -17,9 +17,9 @@
package org.apache.celeborn.service.deploy.worker
+import java.{lang, util}
import java.io.{FileNotFoundException, IOException}
import java.nio.charset.StandardCharsets
-import java.util
import java.util.concurrent.atomic.AtomicBoolean
import java.util.function.Consumer
@@ -36,7 +36,7 @@ import org.apache.celeborn.common.network.protocol._
import org.apache.celeborn.common.network.protocol.Message.Type
import org.apache.celeborn.common.network.server.BaseMessageHandler
import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
-import org.apache.celeborn.common.protocol.PartitionType
+import org.apache.celeborn.common.protocol.{MessageType, PartitionType,
PbOpenStream, PbStreamHandler}
import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager,
CreditStreamManager, PartitionFilesSorter, StorageManager}
@@ -94,51 +94,110 @@ class FetchHandler(val conf: CelebornConf, val
transportConf: TransportConf)
case r: ChunkFetchRequest =>
handleChunkFetchRequest(client, r)
case r: RpcRequest =>
- val msg = Message.decode(r.body().nioByteBuffer())
- handleOpenStream(client, r, msg)
+ // process PbOpenStream RPC
+ var timerShuffleKey: String = null
+ try {
+ val pbMsg = TransportMessage.fromByteBuffer(r.body().nioByteBuffer())
+ val pbOpenStream = pbMsg.getParsedPayload[PbOpenStream]
+ val (shuffleKey, fileName, startIndex, endIndex, initialCredit) =
+ (
+ pbOpenStream.getShuffleKey,
+ pbOpenStream.getFileName,
+ pbOpenStream.getStartIndex,
+ pbOpenStream.getEndIndex,
+ pbOpenStream.getInitialCredit)
+
+ timerShuffleKey = shuffleKey
+ workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME,
timerShuffleKey)
+ handleOpenStreamInternal(
+ client,
+ shuffleKey,
+ fileName,
+ startIndex,
+ endIndex,
+ initialCredit,
+ r,
+ false)
+ } catch {
+ case _: Exception =>
+ // process legacy OpenStream RPCs
+ logDebug("Open stream with legacy RPCs")
+ try {
+ val decodedMsg = Message.decode(r.body().nioByteBuffer())
+ val (shuffleKey, fileName) =
+ if (decodedMsg.`type`() == Type.OPEN_STREAM) {
+ val openStream = decodedMsg.asInstanceOf[OpenStream]
+ (
+ new String(openStream.shuffleKey, StandardCharsets.UTF_8),
+ new String(openStream.fileName, StandardCharsets.UTF_8))
+ } else {
+ val openStreamWithCredit =
decodedMsg.asInstanceOf[OpenStreamWithCredit]
+ (
+ new String(openStreamWithCredit.shuffleKey,
StandardCharsets.UTF_8),
+ new String(openStreamWithCredit.fileName,
StandardCharsets.UTF_8))
+ }
+ timerShuffleKey = shuffleKey
+ var startIndex = 0
+ var endIndex = 0
+ var initialCredit = 0
+ getRawFileInfo(shuffleKey, fileName).getPartitionType match {
+ case PartitionType.REDUCE =>
+ startIndex =
decodedMsg.asInstanceOf[OpenStream].startMapIndex
+ endIndex = decodedMsg.asInstanceOf[OpenStream].endMapIndex
+ case PartitionType.MAP =>
+ initialCredit =
decodedMsg.asInstanceOf[OpenStreamWithCredit].initialCredit
+ startIndex =
decodedMsg.asInstanceOf[OpenStreamWithCredit].startIndex
+ endIndex =
decodedMsg.asInstanceOf[OpenStreamWithCredit].endIndex
+ case PartitionType.MAPGROUP =>
+ }
+ handleOpenStreamInternal(
+ client,
+ shuffleKey,
+ fileName,
+ startIndex,
+ endIndex,
+ initialCredit,
+ r,
+ true)
+ } catch {
+ case e: IOException =>
+ handleRpcIOException(client, r.requestId, e)
+ }
+ } finally {
+ r.body().release()
+ workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME,
timerShuffleKey)
+ }
case unknown: RequestMessage =>
throw new IllegalArgumentException(s"Unknown message type id:
${unknown.`type`.id}")
}
}
- // here are BackLogAnnouncement,OpenStream and OpenStreamWithCredit RPCs to
handle
- def handleOpenStream(client: TransportClient, request: RpcRequest, msg:
Message): Unit = {
- val (shuffleKey, fileName) =
- if (msg.`type`() == Type.OPEN_STREAM) {
- val openStream = msg.asInstanceOf[OpenStream]
- (
- new String(openStream.shuffleKey, StandardCharsets.UTF_8),
- new String(openStream.fileName, StandardCharsets.UTF_8))
- } else {
- val openStreamWithCredit = msg.asInstanceOf[OpenStreamWithCredit]
- (
- new String(openStreamWithCredit.shuffleKey, StandardCharsets.UTF_8),
- new String(openStreamWithCredit.fileName, StandardCharsets.UTF_8))
- }
- // metrics start
- workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
+ private def handleOpenStreamInternal(
+ client: TransportClient,
+ shuffleKey: String,
+ fileName: String,
+ startIndex: Int,
+ endIndex: Int,
+ initialCredit: Int,
+ request: RpcRequest,
+ isLegacy: Boolean): Unit = {
try {
var fileInfo = getRawFileInfo(shuffleKey, fileName)
- try fileInfo.getPartitionType match {
+ fileInfo.getPartitionType match {
case PartitionType.REDUCE =>
- val startMapIndex = msg.asInstanceOf[OpenStream].startMapIndex
- val endMapIndex = msg.asInstanceOf[OpenStream].endMapIndex
- if (endMapIndex != Integer.MAX_VALUE) {
+ if (endIndex != Integer.MAX_VALUE) {
fileInfo = partitionsSorter.getSortedFileInfo(
shuffleKey,
fileName,
fileInfo,
- startMapIndex,
- endMapIndex)
+ startIndex,
+ endIndex)
}
- logDebug(s"Received chunk fetch request $shuffleKey $fileName
$startMapIndex " +
- s"$endMapIndex get file info $fileInfo from client channel " +
+ logDebug(s"Received chunk fetch request $shuffleKey $fileName
$startIndex " +
+ s"$endIndex get file info $fileInfo from client channel " +
s"${NettyUtils.getRemoteAddress(client.getChannel)}")
if (fileInfo.isHdfs) {
- val streamHandle = new StreamHandle(0, 0)
- client.getChannel.writeAndFlush(new RpcResponse(
- request.requestId,
- new NioManagedBuffer(streamHandle.toByteBuffer)))
+ replyStreamHandler(client, request.requestId, 0, 0, isLegacy)
} else {
val buffers = new FileManagedBuffers(fileInfo, transportConf)
val fetchTimeMetrics =
storageManager.getFetchTimeMetric(fileInfo.getFile)
@@ -146,54 +205,56 @@ class FetchHandler(val conf: CelebornConf, val
transportConf: TransportConf)
shuffleKey,
buffers,
fetchTimeMetrics)
- val streamHandle = new StreamHandle(streamId, fileInfo.numChunks())
if (fileInfo.numChunks() == 0)
logDebug(s"StreamId $streamId, fileName $fileName, mapRange " +
- s"[$startMapIndex-$endMapIndex] is empty. Received from client
channel " +
+ s"[$startIndex-$endIndex] is empty. Received from client
channel " +
s"${NettyUtils.getRemoteAddress(client.getChannel)}")
else logDebug(
s"StreamId $streamId, fileName $fileName, numChunks
${fileInfo.numChunks()}, " +
- s"mapRange [$startMapIndex-$endMapIndex]. Received from client
channel " +
+ s"mapRange [$startIndex-$endIndex]. Received from client
channel " +
s"${NettyUtils.getRemoteAddress(client.getChannel)}")
- client.getChannel.writeAndFlush(new RpcResponse(
- request.requestId,
- new NioManagedBuffer(streamHandle.toByteBuffer)))
+ replyStreamHandler(client, request.requestId, streamId,
fileInfo.numChunks(), isLegacy)
}
case PartitionType.MAP =>
- val initialCredit =
msg.asInstanceOf[OpenStreamWithCredit].initialCredit
- val startIndex = msg.asInstanceOf[OpenStreamWithCredit].startIndex
- val endIndex = msg.asInstanceOf[OpenStreamWithCredit].endIndex
-
- val callback = new Consumer[java.lang.Long] {
- override def accept(streamId: java.lang.Long): Unit = {
- val bufferStreamHandle = new StreamHandle(streamId, 0)
- client.getChannel.writeAndFlush(new RpcResponse(
- request.requestId,
- new NioManagedBuffer(bufferStreamHandle.toByteBuffer)))
+ val creditStreamHandler =
+ new Consumer[java.lang.Long] {
+ override def accept(streamId: java.lang.Long): Unit = {
+ replyStreamHandler(client, request.requestId, streamId, 0,
isLegacy)
+ }
}
- }
creditStreamManager.registerStream(
- callback,
+ creditStreamHandler,
client.getChannel,
initialCredit,
startIndex,
endIndex,
fileInfo)
-
case PartitionType.MAPGROUP =>
- } catch {
- case e: IOException =>
- handleRpcIOException(client, request.requestId, e)
- } finally {
- // metrics end
- workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
- request.body().release()
}
} catch {
- case ioe: IOException =>
- workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
- handleRpcIOException(client, request.requestId, ioe)
+ case e: IOException =>
+ handleRpcIOException(client, request.requestId, e)
+ }
+ }
+
+ private def replyStreamHandler(
+ client: TransportClient,
+ requestId: Long,
+ streamId: Long,
+ numChunks: Int,
+ isLegacy: Boolean): Unit = {
+ if (isLegacy) {
+ client.getChannel.writeAndFlush(new RpcResponse(
+ requestId,
+ new NioManagedBuffer(new StreamHandle(streamId,
numChunks).toByteBuffer)))
+ } else {
+ client.getChannel.writeAndFlush(new RpcResponse(
+ requestId,
+ new NioManagedBuffer(new TransportMessage(
+ MessageType.STREAM_HANDLER,
+ PbStreamHandler.newBuilder.setStreamId(streamId).setNumChunks(
+ numChunks).build.toByteArray).toByteBuffer)))
}
}
diff --git
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
index 81decac0c..ac13df7c7 100644
---
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
+++
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
@@ -55,6 +55,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornException;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.meta.FileInfo;
import org.apache.celeborn.common.network.TransportContext;
@@ -62,15 +63,11 @@ import
org.apache.celeborn.common.network.buffer.ManagedBuffer;
import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.Message;
-import org.apache.celeborn.common.network.protocol.OpenStream;
-import org.apache.celeborn.common.network.protocol.StreamHandle;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.server.TransportServer;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.network.util.TransportConf;
-import org.apache.celeborn.common.protocol.PartitionSplitMode;
-import org.apache.celeborn.common.protocol.PartitionType;
-import org.apache.celeborn.common.protocol.StorageInfo;
+import org.apache.celeborn.common.protocol.*;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.common.util.Utils;
@@ -198,19 +195,25 @@ public class FileWriterSuiteJ {
}
public ByteBuffer createOpenMessage() {
- byte[] shuffleKeyBytes = "shuffleKey".getBytes(StandardCharsets.UTF_8);
- byte[] fileNameBytes = "location".getBytes(StandardCharsets.UTF_8);
-
- OpenStream openBlocks = new OpenStream(shuffleKeyBytes, fileNameBytes, 0,
Integer.MAX_VALUE);
-
- return openBlocks.toByteBuffer();
+ TransportMessage message =
+ new TransportMessage(
+ MessageType.OPEN_STREAM,
+ PbOpenStream.newBuilder()
+ .setShuffleKey("shuffleKey")
+ .setFileName("location")
+ .setStartIndex(0)
+ .setEndIndex(Integer.MAX_VALUE)
+ .build()
+ .toByteArray());
+
+ return message.toByteBuffer();
}
- private void setUpConn(TransportClient client) throws IOException {
+ private void setUpConn(TransportClient client) throws IOException,
CelebornException {
ByteBuffer resp = client.sendRpcSync(createOpenMessage(), 10000);
- StreamHandle streamHandle = (StreamHandle) Message.decode(resp);
- streamId = streamHandle.streamId;
- numChunks = streamHandle.numChunks;
+ PbStreamHandler streamHandle =
TransportMessage.fromByteBuffer(resp).getParsedPayload();
+ streamId = streamHandle.getStreamId();
+ numChunks = streamHandle.getNumChunks();
}
private FetchResult fetchChunks(TransportClient client, List<Integer>
chunkIndices)