This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/branch-0.3 by this push:
new 00251aca9 [CELEBORN-770][FLINK] Convert BacklogAnnouncement,
BufferStreamEnd, ReadAddCredit to PB
00251aca9 is described below
commit 00251aca999d2cbcfdbedc7af509d280e954d42a
Author: SteNicholas <[email protected]>
AuthorDate: Mon Sep 25 10:44:48 2023 +0800
[CELEBORN-770][FLINK] Convert BacklogAnnouncement, BufferStreamEnd,
ReadAddCredit to PB
`BacklogAnnouncement`, `BufferStreamEnd`, and `ReadAddCredit` should be
converted to PB-based transport messages to enhance Celeborn's compatibility.
1. Improves the flexibility of Celeborn's transport layer when RPCs change.
2. Makes it compatible with the 0.2 client.
No.
- `TransportFrameDecoderWithBufferSupplierSuiteJ`
Closes #1905 from SteNicholas/CELEBORN-770.
Authored-by: SteNicholas <[email protected]>
Signed-off-by: Shuang <[email protected]>
(cherry picked from commit 2407cae43ab44b6d7a7394736e7c12cbbd51ebb5)
Signed-off-by: Shuang <[email protected]>
---
.../plugin/flink/RemoteBufferStreamReader.java | 9 +-
.../plugin/flink/network/ReadClientHandler.java | 32 +++--
.../flink/readclient/CelebornBufferStream.java | 46 ++++---
...nsportFrameDecoderWithBufferSupplierSuiteJ.java | 26 +++-
.../common/network/client/TransportClient.java | 16 +++
.../network/protocol/BacklogAnnouncement.java | 6 +
.../common/network/protocol/BufferStreamEnd.java | 6 +
.../common/network/protocol/ReadAddCredit.java | 1 +
.../common/network/protocol/TransportMessage.java | 12 ++
common/src/main/proto/TransportMessages.proto | 23 +++-
.../worker/storage/MapDataPartitionReader.java | 20 ++-
.../service/deploy/worker/FetchHandler.scala | 153 +++++++++++----------
12 files changed, 245 insertions(+), 105 deletions(-)
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
index e960495bb..51dadf9f1 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
@@ -25,10 +25,10 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
-import org.apache.celeborn.common.network.protocol.ReadAddCredit;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.util.NettyUtils;
+import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.plugin.flink.buffer.CreditListener;
import org.apache.celeborn.plugin.flink.buffer.TransferBufferPool;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
@@ -115,8 +115,11 @@ public class RemoteBufferStreamReader extends
CreditListener {
public void notifyAvailableCredits(int numCredits) {
if (!closed) {
- ReadAddCredit addCredit = new ReadAddCredit(bufferStream.getStreamId(),
numCredits);
- bufferStream.addCredit(addCredit);
+ bufferStream.addCredit(
+ PbReadAddCredit.newBuilder()
+ .setStreamId(bufferStream.getStreamId())
+ .setCredit(numCredits)
+ .build());
}
}
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index 9340334a9..5c100002c 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -17,6 +17,10 @@
package org.apache.celeborn.plugin.flink.network;
+import static
org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
+import static
org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
+
+import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
@@ -28,6 +32,7 @@ import
org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.util.JavaUtils;
@@ -66,32 +71,43 @@ public class ReadClientHandler extends BaseMessageHandler {
@Override
public void receive(TransportClient client, RequestMessage msg) {
- long streamId = 0;
switch (msg.type()) {
case READ_DATA:
ReadData readData = (ReadData) msg;
- streamId = readData.getStreamId();
- processMessageInternal(streamId, readData);
+ processMessageInternal(readData.getStreamId(), readData);
break;
case BACKLOG_ANNOUNCEMENT:
BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
- streamId = backlogAnnouncement.getStreamId();
- processMessageInternal(streamId, backlogAnnouncement);
+ processMessageInternal(backlogAnnouncement.getStreamId(),
backlogAnnouncement);
break;
case TRANSPORTABLE_ERROR:
TransportableError transportableError = ((TransportableError) msg);
- streamId = transportableError.getStreamId();
logger.warn(
"Received TransportableError from worker {} with content {}",
client.getSocketAddress().toString(),
transportableError.getErrorMessage());
- processMessageInternal(streamId, transportableError);
+ processMessageInternal(transportableError.getStreamId(),
transportableError);
break;
case BUFFER_STREAM_END:
BufferStreamEnd streamEnd = (BufferStreamEnd) msg;
- logger.debug("Received streamend for {}", streamEnd.getStreamId());
processMessageInternal(streamEnd.getStreamId(), streamEnd);
break;
+ case RPC_REQUEST:
+ try {
+ TransportMessage transportMessage =
+ TransportMessage.fromByteBuffer(msg.body().nioByteBuffer());
+ switch (transportMessage.getMessageTypeValue()) {
+ case BACKLOG_ANNOUNCEMENT_VALUE:
+ receive(client,
BacklogAnnouncement.fromProto(transportMessage.getParsedPayload()));
+ break;
+ case BUFFER_STREAM_END_VALUE:
+ receive(client,
BufferStreamEnd.fromProto(transportMessage.getParsedPayload()));
+ break;
+ }
+ } catch (IOException e) {
+ logger.warn("Failed to process RpcRequest message {}. ", msg, e);
+ }
+ break;
case ONE_WAY_MESSAGE:
// ignore it.
break;
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
index 4fc9d7384..fcd9e0458 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
@@ -29,11 +29,15 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
-import org.apache.celeborn.common.network.protocol.*;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;
@@ -82,21 +86,25 @@ public class CelebornBufferStream {
moveToNextPartitionIfPossible(0);
}
- public void addCredit(ReadAddCredit addCredit) {
- this.client
- .getChannel()
- .writeAndFlush(addCredit)
- .addListener(
- future -> {
- if (future.isSuccess()) {
- // Send ReadAddCredit do not expect response.
- } else {
- logger.warn(
- "Send ReadAddCredit to {} failed, detail {}",
- this.client.getSocketAddress().toString(),
- future.cause());
- }
- });
+ public void addCredit(PbReadAddCredit pbReadAddCredit) {
+ this.client.sendRpc(
+ new TransportMessage(MessageType.READ_ADD_CREDIT,
pbReadAddCredit.toByteArray())
+ .toByteBuffer(),
+ new RpcResponseCallback() {
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ // Send PbReadAddCredit do not expect response.
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.warn(
+ "Send PbReadAddCredit to {} failed, detail {}",
+ NettyUtils.getRemoteAddress(client.getChannel()),
+ e.getCause());
+ }
+ });
}
public static CelebornBufferStream empty() {
@@ -127,7 +135,11 @@ public class CelebornBufferStream {
private void closeStream(long streamId) {
if (client != null && client.isActive()) {
- client.getChannel().writeAndFlush(new BufferStreamEnd(streamId));
+ client.sendRpc(
+ new TransportMessage(
+ MessageType.BUFFER_STREAM_END,
+
PbBufferStreamEnd.newBuilder().setStreamId(streamId).build().toByteArray())
+ .toByteBuffer());
}
}
diff --git
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
index 073972b44..64696abec 100644
---
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
+++
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
@@ -17,6 +17,8 @@
package org.apache.celeborn.plugin.flink.network;
+import static
org.apache.celeborn.common.network.client.TransportClient.requestId;
+
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -31,9 +33,13 @@ import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
-import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
+import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.protocol.ReadData;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.util.JavaUtils;
public class TransportFrameDecoderWithBufferSupplierSuiteJ {
@@ -57,10 +63,10 @@ public class TransportFrameDecoderWithBufferSupplierSuiteJ {
new TransportFrameDecoderWithBufferSupplier(supplier);
ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
- BacklogAnnouncement announcement = new BacklogAnnouncement(0, 0);
+ RpcRequest announcement = createBacklogAnnouncement(0, 0);
ReadData unUsedReadData = new ReadData(1, generateData(1024));
ReadData readData = new ReadData(2, generateData(1024));
- BacklogAnnouncement announcement1 = new BacklogAnnouncement(0, 0);
+ RpcRequest announcement1 = createBacklogAnnouncement(0, 0);
ReadData unUsedReadData1 = new ReadData(1, generateData(1024));
ReadData readData1 = new ReadData(2, generateData(8));
@@ -102,6 +108,20 @@ public class TransportFrameDecoderWithBufferSupplierSuiteJ
{
Assert.assertEquals(buffers.size(), 6);
}
+ public RpcRequest createBacklogAnnouncement(long streamId, int backlog) {
+ return new RpcRequest(
+ requestId(),
+ new NioManagedBuffer(
+ new TransportMessage(
+ MessageType.BACKLOG_ANNOUNCEMENT,
+ PbBacklogAnnouncement.newBuilder()
+ .setStreamId(streamId)
+ .setBacklog(backlog)
+ .build()
+ .toByteArray())
+ .toByteBuffer()));
+ }
+
public ByteBuf encodeMessage(Message in, ByteBuf byteBuf) throws IOException
{
byteBuf.writeInt(in.encodedLength());
in.type().encode(byteBuf);
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index 2c9eca4cb..697ca2d26 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -170,6 +170,22 @@ public class TransportClient implements Closeable {
return requestId;
}
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side.
+ *
+ * @param message The message to send.
+ * @return The RPC's id.
+ */
+ public long sendRpc(ByteBuffer message) {
+ if (logger.isTraceEnabled()) {
+ logger.trace("Sending RPC to {}", NettyUtils.getRemoteAddress(channel));
+ }
+
+ long requestId = requestId();
+ channel.writeAndFlush(new RpcRequest(requestId, new
NioManagedBuffer(message)));
+ return requestId;
+ }
+
public ChannelFuture pushData(
PushData pushData, long pushDataTimeout, RpcResponseCallback callback) {
return pushData(pushData, pushDataTimeout, callback, null);
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BacklogAnnouncement.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BacklogAnnouncement.java
index 45f02f5d8..daccb7bf2 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BacklogAnnouncement.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BacklogAnnouncement.java
@@ -21,6 +21,8 @@ import static
org.apache.celeborn.common.network.protocol.Message.Type.BACKLOG_A
import io.netty.buffer.ByteBuf;
+import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
+
// This RPC is sent to flink plugin to tell flink client to be ready for
buffers.
public class BacklogAnnouncement extends RequestMessage {
private long streamId;
@@ -60,4 +62,8 @@ public class BacklogAnnouncement extends RequestMessage {
public int getBacklog() {
return backlog;
}
+
+ public static BacklogAnnouncement fromProto(PbBacklogAnnouncement pb) {
+ return new BacklogAnnouncement(pb.getStreamId(), pb.getBacklog());
+ }
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
index d85e380d1..8b86fa547 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
@@ -19,6 +19,8 @@ package org.apache.celeborn.common.network.protocol;
import io.netty.buffer.ByteBuf;
+import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
+
public class BufferStreamEnd extends RequestMessage {
private long streamId;
@@ -49,4 +51,8 @@ public class BufferStreamEnd extends RequestMessage {
public long getStreamId() {
return streamId;
}
+
+ public static BufferStreamEnd fromProto(PbBufferStreamEnd pb) {
+ return new BufferStreamEnd(pb.getStreamId());
+ }
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
index ca34a5c17..27fa54288 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
@@ -20,6 +20,7 @@ import java.util.Objects;
import io.netty.buffer.ByteBuf;
+@Deprecated
public class ReadAddCredit extends RequestMessage {
private long streamId;
private int credit;
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index d72bfec1b..769eef0d4 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -17,7 +17,10 @@
package org.apache.celeborn.common.network.protocol;
+import static
org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
+import static
org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
import static
org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
+import static
org.apache.celeborn.common.protocol.MessageType.READ_ADD_CREDIT_VALUE;
import static
org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;
import java.io.Serializable;
@@ -30,7 +33,10 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
+import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbStreamHandler;
public class TransportMessage implements Serializable {
@@ -64,6 +70,12 @@ public class TransportMessage implements Serializable {
return (T) PbOpenStream.parseFrom(payload);
case STREAM_HANDLER_VALUE:
return (T) PbStreamHandler.parseFrom(payload);
+ case BACKLOG_ANNOUNCEMENT_VALUE:
+ return (T) PbBacklogAnnouncement.parseFrom(payload);
+ case BUFFER_STREAM_END_VALUE:
+ return (T) PbBufferStreamEnd.parseFrom(payload);
+ case READ_ADD_CREDIT_VALUE:
+ return (T) PbReadAddCredit.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index 25cbeff50..e31725edc 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -76,6 +76,9 @@ enum MessageType {
CHECK_WORKERS_AVAILABLE = 53;
CHECK_WORKERS_AVAILABLE_RESPONSE = 54;
REMOVE_WORKERS_UNAVAILABLE_INFO = 55;
+ BACKLOG_ANNOUNCEMENT = 59;
+ BUFFER_STREAM_END = 60;
+ READ_ADD_CREDIT = 61;
}
message PbStorageInfo {
@@ -505,8 +508,22 @@ message PbOpenStream {
}
message PbStreamHandler {
- int64 streamId = 1 ;
+ int64 streamId = 1;
int32 numChunks = 2;
- repeated int64 chunkOffsets = 3 ;
+ repeated int64 chunkOffsets = 3;
string fullPath = 4;
-}
\ No newline at end of file
+}
+
+message PbBacklogAnnouncement {
+ int64 streamId = 1;
+ int32 backlog = 2;
+}
+
+message PbBufferStreamEnd {
+ int64 streamId = 1;
+}
+
+message PbReadAddCredit {
+ int64 streamId = 1;
+ int32 credit = 2;
+}
diff --git
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapDataPartitionReader.java
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapDataPartitionReader.java
index b8f996fe7..cecd47900 100644
---
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapDataPartitionReader.java
+++
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapDataPartitionReader.java
@@ -17,6 +17,8 @@
package org.apache.celeborn.service.deploy.worker.storage;
+import static
org.apache.celeborn.common.network.client.TransportClient.requestId;
+
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
@@ -36,11 +38,15 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.exception.FileCorruptedException;
import org.apache.celeborn.common.meta.FileInfo;
+import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
-import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.ReadData;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.util.NettyUtils;
+import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.util.ExceptionUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.service.deploy.worker.memory.BufferQueue;
@@ -442,7 +448,17 @@ public class MapDataPartitionReader implements
Comparable<MapDataPartitionReader
// old client can't support BufferStreamEnd, so for new client it
tells client that this
// stream is finished.
if (fileInfo.isPartitionSplitEnabled() && !errorNotified)
- associatedChannel.writeAndFlush(new BufferStreamEnd(streamId));
+ associatedChannel.writeAndFlush(
+ new RpcRequest(
+ requestId(),
+ new NioManagedBuffer(
+ new TransportMessage(
+ MessageType.BUFFER_STREAM_END,
+ PbBufferStreamEnd.newBuilder()
+ .setStreamId(streamId)
+ .build()
+ .toByteArray())
+ .toByteBuffer())));
if (!buffersToSend.isEmpty()) {
numInUseBuffers.addAndGet(-1 * buffersToSend.size());
buffersToSend.forEach(RecyclableBuffer::recycle);
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 019c8871e..f7e0e51a6 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -26,6 +26,7 @@ import java.util.function.Consumer
import scala.collection.JavaConverters.asScalaBufferConverter
import com.google.common.base.Throwables
+import com.google.protobuf.GeneratedMessageV3
import io.netty.util.concurrent.{Future, GenericFutureListener}
import org.apache.celeborn.common.CelebornConf
@@ -38,7 +39,7 @@ import org.apache.celeborn.common.network.protocol._
import org.apache.celeborn.common.network.protocol.Message.Type
import org.apache.celeborn.common.network.server.BaseMessageHandler
import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
-import org.apache.celeborn.common.protocol.{MessageType, PartitionType,
PbOpenStream, PbStreamHandler}
+import org.apache.celeborn.common.protocol.{MessageType, PartitionType,
PbBufferStreamEnd, PbOpenStream, PbReadAddCredit, PbStreamHandler}
import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager,
CreditStreamManager, PartitionFilesSorter, StorageManager}
@@ -90,72 +91,30 @@ class FetchHandler(val conf: CelebornConf, val
transportConf: TransportConf)
override def receive(client: TransportClient, msg: RequestMessage): Unit = {
msg match {
case r: BufferStreamEnd =>
- handleEndStreamFromClient(r)
+ handleEndStreamFromClient(r.getStreamId)
case r: ReadAddCredit =>
- handleReadAddCredit(r)
+ handleReadAddCredit(r.getCredit, r.getStreamId)
case r: ChunkFetchRequest =>
handleChunkFetchRequest(client, r)
case r: RpcRequest =>
- // process PbOpenStream RPC
var streamShuffleKey: String = null
- var streamFileName: String = null
try {
- val pbMsg = TransportMessage.fromByteBuffer(r.body().nioByteBuffer())
- val pbOpenStream = pbMsg.getParsedPayload[PbOpenStream]
- val (shuffleKey, fileName, startIndex, endIndex, initialCredit,
readLocalShuffle) =
- (
- pbOpenStream.getShuffleKey,
- pbOpenStream.getFileName,
- pbOpenStream.getStartIndex,
- pbOpenStream.getEndIndex,
- pbOpenStream.getInitialCredit,
- pbOpenStream.getReadLocalShuffle)
- streamShuffleKey = shuffleKey
- streamFileName = fileName
- workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME,
streamShuffleKey)
- handleOpenStreamInternal(
- client,
- shuffleKey,
- fileName,
- startIndex,
- endIndex,
- initialCredit,
- r,
- false,
- readLocalShuffle)
- } catch {
- case _: Exception =>
- // process legacy OpenStream RPCs
- logDebug("Open stream with legacy RPCs")
- try {
- val decodedMsg = Message.decode(r.body().nioByteBuffer())
- val (shuffleKey, fileName) =
- if (decodedMsg.`type`() == Type.OPEN_STREAM) {
- val openStream = decodedMsg.asInstanceOf[OpenStream]
- (
- new String(openStream.shuffleKey, StandardCharsets.UTF_8),
- new String(openStream.fileName, StandardCharsets.UTF_8))
- } else {
- val openStreamWithCredit =
decodedMsg.asInstanceOf[OpenStreamWithCredit]
- (
- new String(openStreamWithCredit.shuffleKey,
StandardCharsets.UTF_8),
- new String(openStreamWithCredit.fileName,
StandardCharsets.UTF_8))
- }
+ val pbMsg = TransportMessage.fromByteBuffer(
+
r.body().nioByteBuffer()).getParsedPayload.asInstanceOf[GeneratedMessageV3]
+ pbMsg match {
+ case pb: PbBufferStreamEnd =>
handleEndStreamFromClient(pb.getStreamId)
+ case pb: PbReadAddCredit => handleReadAddCredit(pb.getCredit,
pb.getStreamId)
+ case pb: PbOpenStream =>
+ val (shuffleKey, fileName, startIndex, endIndex, initialCredit,
readLocalShuffle) =
+ (
+ pb.getShuffleKey,
+ pb.getFileName,
+ pb.getStartIndex,
+ pb.getEndIndex,
+ pb.getInitialCredit,
+ pb.getReadLocalShuffle)
streamShuffleKey = shuffleKey
- streamFileName = fileName
- var startIndex = 0
- var endIndex = 0
- var initialCredit = 0
- getRawFileInfo(shuffleKey, fileName).getPartitionType match {
- case PartitionType.REDUCE =>
- startIndex =
decodedMsg.asInstanceOf[OpenStream].startMapIndex
- endIndex = decodedMsg.asInstanceOf[OpenStream].endMapIndex
- case PartitionType.MAP =>
- initialCredit =
decodedMsg.asInstanceOf[OpenStreamWithCredit].initialCredit
- startIndex =
decodedMsg.asInstanceOf[OpenStreamWithCredit].startIndex
- endIndex =
decodedMsg.asInstanceOf[OpenStreamWithCredit].endIndex
- case PartitionType.MAPGROUP =>
- }
+ workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME,
streamShuffleKey)
handleOpenStreamInternal(
client,
shuffleKey,
@@ -164,14 +123,63 @@ class FetchHandler(val conf: CelebornConf, val
transportConf: TransportConf)
endIndex,
initialCredit,
r,
- true)
- } catch {
- case e: IOException =>
- handleRpcIOException(client, r.requestId, streamShuffleKey,
streamFileName, e)
+ false,
+ readLocalShuffle)
+ }
+ } catch {
+ case _: Exception =>
+ logDebug("Legacy RPCs")
+ val decodedMsg = Message.decode(r.body().nioByteBuffer())
+ val msgType = decodedMsg.`type`()
+ if (msgType == Type.OPEN_STREAM || msgType ==
Type.OPEN_STREAM_WITH_CREDIT) {
+ var streamFileName: String = null
+ try {
+ val (shuffleKey, fileName) =
+ if (msgType == Type.OPEN_STREAM) {
+ val openStream = decodedMsg.asInstanceOf[OpenStream]
+ (
+ new String(openStream.shuffleKey,
StandardCharsets.UTF_8),
+ new String(openStream.fileName, StandardCharsets.UTF_8))
+ } else {
+ val openStreamWithCredit =
decodedMsg.asInstanceOf[OpenStreamWithCredit]
+ (
+ new String(openStreamWithCredit.shuffleKey,
StandardCharsets.UTF_8),
+ new String(openStreamWithCredit.fileName,
StandardCharsets.UTF_8))
+ }
+ streamShuffleKey = shuffleKey
+ streamFileName = fileName
+ var startIndex = 0
+ var endIndex = 0
+ var initialCredit = 0
+ getRawFileInfo(shuffleKey, fileName).getPartitionType match {
+ case PartitionType.REDUCE =>
+ startIndex =
decodedMsg.asInstanceOf[OpenStream].startMapIndex
+ endIndex = decodedMsg.asInstanceOf[OpenStream].endMapIndex
+ case PartitionType.MAP =>
+ initialCredit =
decodedMsg.asInstanceOf[OpenStreamWithCredit].initialCredit
+ startIndex =
decodedMsg.asInstanceOf[OpenStreamWithCredit].startIndex
+ endIndex =
decodedMsg.asInstanceOf[OpenStreamWithCredit].endIndex
+ case PartitionType.MAPGROUP =>
+ }
+ handleOpenStreamInternal(
+ client,
+ shuffleKey,
+ fileName,
+ startIndex,
+ endIndex,
+ initialCredit,
+ r,
+ true)
+ } catch {
+ case e: IOException =>
+ handleRpcIOException(client, r.requestId, streamShuffleKey,
streamFileName, e)
+ }
}
} finally {
r.body().release()
- workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME,
streamShuffleKey)
+ if (streamShuffleKey != null) {
+ workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME,
streamShuffleKey)
+ }
}
case unknown: RequestMessage =>
throw new IllegalArgumentException(s"Unknown message type id:
${unknown.`type`.id}")
@@ -296,17 +304,24 @@ class FetchHandler(val conf: CelebornConf, val
transportConf: TransportConf)
logError(
s"Read file: $fileName with shuffleKey: $shuffleKey error from
${NettyUtils.getRemoteAddress(client.getChannel)}",
ioe)
+ handleRpcException(client, requestId, ioe)
+ }
+
+ private def handleRpcException(
+ client: TransportClient,
+ requestId: Long,
+ ioe: IOException): Unit = {
client.getChannel.writeAndFlush(new RpcFailure(
requestId,
Throwables.getStackTraceAsString(ExceptionUtils.wrapIOExceptionToUnRetryable(ioe))))
}
- def handleEndStreamFromClient(req: BufferStreamEnd): Unit = {
- creditStreamManager.notifyStreamEndByClient(req.getStreamId)
+ def handleEndStreamFromClient(streamId: Long): Unit = {
+ creditStreamManager.notifyStreamEndByClient(streamId)
}
- def handleReadAddCredit(req: ReadAddCredit): Unit = {
- creditStreamManager.addCredit(req.getCredit, req.getStreamId)
+ def handleReadAddCredit(credit: Int, streamId: Long): Unit = {
+ creditStreamManager.addCredit(credit, streamId)
}
def handleChunkFetchRequest(client: TransportClient, req:
ChunkFetchRequest): Unit = {