This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.3 by this push:
     new 00251aca9 [CELEBORN-770][FLINK] Convert BacklogAnnouncement, 
BufferStreamEnd, ReadAddCredit to PB
00251aca9 is described below

commit 00251aca999d2cbcfdbedc7af509d280e954d42a
Author: SteNicholas <[email protected]>
AuthorDate: Mon Sep 25 10:44:48 2023 +0800

    [CELEBORN-770][FLINK] Convert BacklogAnnouncement, BufferStreamEnd, 
ReadAddCredit to PB
    
    `BacklogAnnouncement`, `BufferStreamEnd`, and `ReadAddCredit` should be merged 
into transport messages to enhance Celeborn's compatibility.
    
    1. Improves Celeborn's transport flexibility to change RPCs.
    2. Makes it compatible with the 0.2 client.
    
    No.
    
    - `TransportFrameDecoderWithBufferSupplierSuiteJ`
    
    Closes #1905 from SteNicholas/CELEBORN-770.
    
    Authored-by: SteNicholas <[email protected]>
    Signed-off-by: Shuang <[email protected]>
    (cherry picked from commit 2407cae43ab44b6d7a7394736e7c12cbbd51ebb5)
    Signed-off-by: Shuang <[email protected]>
---
 .../plugin/flink/RemoteBufferStreamReader.java     |   9 +-
 .../plugin/flink/network/ReadClientHandler.java    |  32 +++--
 .../flink/readclient/CelebornBufferStream.java     |  46 ++++---
 ...nsportFrameDecoderWithBufferSupplierSuiteJ.java |  26 +++-
 .../common/network/client/TransportClient.java     |  16 +++
 .../network/protocol/BacklogAnnouncement.java      |   6 +
 .../common/network/protocol/BufferStreamEnd.java   |   6 +
 .../common/network/protocol/ReadAddCredit.java     |   1 +
 .../common/network/protocol/TransportMessage.java  |  12 ++
 common/src/main/proto/TransportMessages.proto      |  23 +++-
 .../worker/storage/MapDataPartitionReader.java     |  20 ++-
 .../service/deploy/worker/FetchHandler.scala       | 153 +++++++++++----------
 12 files changed, 245 insertions(+), 105 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
index e960495bb..51dadf9f1 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
@@ -25,10 +25,10 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
 import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
-import org.apache.celeborn.common.network.protocol.ReadAddCredit;
 import org.apache.celeborn.common.network.protocol.RequestMessage;
 import org.apache.celeborn.common.network.protocol.TransportableError;
 import org.apache.celeborn.common.network.util.NettyUtils;
+import org.apache.celeborn.common.protocol.PbReadAddCredit;
 import org.apache.celeborn.plugin.flink.buffer.CreditListener;
 import org.apache.celeborn.plugin.flink.buffer.TransferBufferPool;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
@@ -115,8 +115,11 @@ public class RemoteBufferStreamReader extends 
CreditListener {
 
   public void notifyAvailableCredits(int numCredits) {
     if (!closed) {
-      ReadAddCredit addCredit = new ReadAddCredit(bufferStream.getStreamId(), 
numCredits);
-      bufferStream.addCredit(addCredit);
+      bufferStream.addCredit(
+          PbReadAddCredit.newBuilder()
+              .setStreamId(bufferStream.getStreamId())
+              .setCredit(numCredits)
+              .build());
     }
   }
 
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index 9340334a9..5c100002c 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -17,6 +17,10 @@
 
 package org.apache.celeborn.plugin.flink.network;
 
+import static 
org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
+import static 
org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
+
+import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
@@ -28,6 +32,7 @@ import 
org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
 import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
 import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.protocol.TransportableError;
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
 import org.apache.celeborn.common.util.JavaUtils;
@@ -66,32 +71,43 @@ public class ReadClientHandler extends BaseMessageHandler {
 
   @Override
   public void receive(TransportClient client, RequestMessage msg) {
-    long streamId = 0;
     switch (msg.type()) {
       case READ_DATA:
         ReadData readData = (ReadData) msg;
-        streamId = readData.getStreamId();
-        processMessageInternal(streamId, readData);
+        processMessageInternal(readData.getStreamId(), readData);
         break;
       case BACKLOG_ANNOUNCEMENT:
         BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
-        streamId = backlogAnnouncement.getStreamId();
-        processMessageInternal(streamId, backlogAnnouncement);
+        processMessageInternal(backlogAnnouncement.getStreamId(), 
backlogAnnouncement);
         break;
       case TRANSPORTABLE_ERROR:
         TransportableError transportableError = ((TransportableError) msg);
-        streamId = transportableError.getStreamId();
         logger.warn(
             "Received TransportableError from worker {} with content {}",
             client.getSocketAddress().toString(),
             transportableError.getErrorMessage());
-        processMessageInternal(streamId, transportableError);
+        processMessageInternal(transportableError.getStreamId(), 
transportableError);
         break;
       case BUFFER_STREAM_END:
         BufferStreamEnd streamEnd = (BufferStreamEnd) msg;
-        logger.debug("Received streamend for {}", streamEnd.getStreamId());
         processMessageInternal(streamEnd.getStreamId(), streamEnd);
         break;
+      case RPC_REQUEST:
+        try {
+          TransportMessage transportMessage =
+              TransportMessage.fromByteBuffer(msg.body().nioByteBuffer());
+          switch (transportMessage.getMessageTypeValue()) {
+            case BACKLOG_ANNOUNCEMENT_VALUE:
+              receive(client, 
BacklogAnnouncement.fromProto(transportMessage.getParsedPayload()));
+              break;
+            case BUFFER_STREAM_END_VALUE:
+              receive(client, 
BufferStreamEnd.fromProto(transportMessage.getParsedPayload()));
+              break;
+          }
+        } catch (IOException e) {
+          logger.warn("Failed to process RpcRequest message {}. ", msg, e);
+        }
+        break;
       case ONE_WAY_MESSAGE:
         // ignore it.
         break;
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
index 4fc9d7384..fcd9e0458 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
@@ -29,11 +29,15 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
-import org.apache.celeborn.common.network.protocol.*;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.network.protocol.TransportableError;
 import org.apache.celeborn.common.network.util.NettyUtils;
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
 import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbReadAddCredit;
 import org.apache.celeborn.common.protocol.PbStreamHandler;
 import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;
 
@@ -82,21 +86,25 @@ public class CelebornBufferStream {
     moveToNextPartitionIfPossible(0);
   }
 
-  public void addCredit(ReadAddCredit addCredit) {
-    this.client
-        .getChannel()
-        .writeAndFlush(addCredit)
-        .addListener(
-            future -> {
-              if (future.isSuccess()) {
-                // Send ReadAddCredit do not expect response.
-              } else {
-                logger.warn(
-                    "Send ReadAddCredit to {} failed, detail {}",
-                    this.client.getSocketAddress().toString(),
-                    future.cause());
-              }
-            });
+  public void addCredit(PbReadAddCredit pbReadAddCredit) {
+    this.client.sendRpc(
+        new TransportMessage(MessageType.READ_ADD_CREDIT, 
pbReadAddCredit.toByteArray())
+            .toByteBuffer(),
+        new RpcResponseCallback() {
+
+          @Override
+          public void onSuccess(ByteBuffer response) {
+            // Send PbReadAddCredit do not expect response.
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            logger.warn(
+                "Send PbReadAddCredit to {} failed, detail {}",
+                NettyUtils.getRemoteAddress(client.getChannel()),
+                e.getCause());
+          }
+        });
   }
 
   public static CelebornBufferStream empty() {
@@ -127,7 +135,11 @@ public class CelebornBufferStream {
 
   private void closeStream(long streamId) {
     if (client != null && client.isActive()) {
-      client.getChannel().writeAndFlush(new BufferStreamEnd(streamId));
+      client.sendRpc(
+          new TransportMessage(
+                  MessageType.BUFFER_STREAM_END,
+                  
PbBufferStreamEnd.newBuilder().setStreamId(streamId).build().toByteArray())
+              .toByteBuffer());
     }
   }
 
diff --git 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
index 073972b44..64696abec 100644
--- 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
+++ 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
@@ -17,6 +17,8 @@
 
 package org.apache.celeborn.plugin.flink.network;
 
+import static 
org.apache.celeborn.common.network.client.TransportClient.requestId;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
@@ -31,9 +33,13 @@ import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.Mockito;
 
-import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
+import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.protocol.Message;
 import org.apache.celeborn.common.network.protocol.ReadData;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
 import org.apache.celeborn.common.util.JavaUtils;
 
 public class TransportFrameDecoderWithBufferSupplierSuiteJ {
@@ -57,10 +63,10 @@ public class TransportFrameDecoderWithBufferSupplierSuiteJ {
         new TransportFrameDecoderWithBufferSupplier(supplier);
     ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
 
-    BacklogAnnouncement announcement = new BacklogAnnouncement(0, 0);
+    RpcRequest announcement = createBacklogAnnouncement(0, 0);
     ReadData unUsedReadData = new ReadData(1, generateData(1024));
     ReadData readData = new ReadData(2, generateData(1024));
-    BacklogAnnouncement announcement1 = new BacklogAnnouncement(0, 0);
+    RpcRequest announcement1 = createBacklogAnnouncement(0, 0);
     ReadData unUsedReadData1 = new ReadData(1, generateData(1024));
     ReadData readData1 = new ReadData(2, generateData(8));
 
@@ -102,6 +108,20 @@ public class TransportFrameDecoderWithBufferSupplierSuiteJ 
{
     Assert.assertEquals(buffers.size(), 6);
   }
 
+  public RpcRequest createBacklogAnnouncement(long streamId, int backlog) {
+    return new RpcRequest(
+        requestId(),
+        new NioManagedBuffer(
+            new TransportMessage(
+                    MessageType.BACKLOG_ANNOUNCEMENT,
+                    PbBacklogAnnouncement.newBuilder()
+                        .setStreamId(streamId)
+                        .setBacklog(backlog)
+                        .build()
+                        .toByteArray())
+                .toByteBuffer()));
+  }
+
   public ByteBuf encodeMessage(Message in, ByteBuf byteBuf) throws IOException 
{
     byteBuf.writeInt(in.encodedLength());
     in.type().encode(byteBuf);
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index 2c9eca4cb..697ca2d26 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -170,6 +170,22 @@ public class TransportClient implements Closeable {
     return requestId;
   }
 
+  /**
+   * Sends an opaque message to the RpcHandler on the server-side.
+   *
+   * @param message The message to send.
+   * @return The RPC's id.
+   */
+  public long sendRpc(ByteBuffer message) {
+    if (logger.isTraceEnabled()) {
+      logger.trace("Sending RPC to {}", NettyUtils.getRemoteAddress(channel));
+    }
+
+    long requestId = requestId();
+    channel.writeAndFlush(new RpcRequest(requestId, new 
NioManagedBuffer(message)));
+    return requestId;
+  }
+
   public ChannelFuture pushData(
       PushData pushData, long pushDataTimeout, RpcResponseCallback callback) {
     return pushData(pushData, pushDataTimeout, callback, null);
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BacklogAnnouncement.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BacklogAnnouncement.java
index 45f02f5d8..daccb7bf2 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BacklogAnnouncement.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BacklogAnnouncement.java
@@ -21,6 +21,8 @@ import static 
org.apache.celeborn.common.network.protocol.Message.Type.BACKLOG_A
 
 import io.netty.buffer.ByteBuf;
 
+import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
+
 // This RPC is sent to flink plugin to tell flink client to be ready for 
buffers.
 public class BacklogAnnouncement extends RequestMessage {
   private long streamId;
@@ -60,4 +62,8 @@ public class BacklogAnnouncement extends RequestMessage {
   public int getBacklog() {
     return backlog;
   }
+
+  public static BacklogAnnouncement fromProto(PbBacklogAnnouncement pb) {
+    return new BacklogAnnouncement(pb.getStreamId(), pb.getBacklog());
+  }
 }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
index d85e380d1..8b86fa547 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
@@ -19,6 +19,8 @@ package org.apache.celeborn.common.network.protocol;
 
 import io.netty.buffer.ByteBuf;
 
+import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
+
 public class BufferStreamEnd extends RequestMessage {
   private long streamId;
 
@@ -49,4 +51,8 @@ public class BufferStreamEnd extends RequestMessage {
   public long getStreamId() {
     return streamId;
   }
+
+  public static BufferStreamEnd fromProto(PbBufferStreamEnd pb) {
+    return new BufferStreamEnd(pb.getStreamId());
+  }
 }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
index ca34a5c17..27fa54288 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
@@ -20,6 +20,7 @@ import java.util.Objects;
 
 import io.netty.buffer.ByteBuf;
 
+@Deprecated
 public class ReadAddCredit extends RequestMessage {
   private long streamId;
   private int credit;
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index d72bfec1b..769eef0d4 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -17,7 +17,10 @@
 
 package org.apache.celeborn.common.network.protocol;
 
+import static 
org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
+import static 
org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
 import static 
org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
+import static 
org.apache.celeborn.common.protocol.MessageType.READ_ADD_CREDIT_VALUE;
 import static 
org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;
 
 import java.io.Serializable;
@@ -30,7 +33,10 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
+import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
 import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbReadAddCredit;
 import org.apache.celeborn.common.protocol.PbStreamHandler;
 
 public class TransportMessage implements Serializable {
@@ -64,6 +70,12 @@ public class TransportMessage implements Serializable {
         return (T) PbOpenStream.parseFrom(payload);
       case STREAM_HANDLER_VALUE:
         return (T) PbStreamHandler.parseFrom(payload);
+      case BACKLOG_ANNOUNCEMENT_VALUE:
+        return (T) PbBacklogAnnouncement.parseFrom(payload);
+      case BUFFER_STREAM_END_VALUE:
+        return (T) PbBufferStreamEnd.parseFrom(payload);
+      case READ_ADD_CREDIT_VALUE:
+        return (T) PbReadAddCredit.parseFrom(payload);
       default:
         logger.error("Unexpected type {}", type);
     }
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 25cbeff50..e31725edc 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -76,6 +76,9 @@ enum MessageType {
   CHECK_WORKERS_AVAILABLE = 53;
   CHECK_WORKERS_AVAILABLE_RESPONSE = 54;
   REMOVE_WORKERS_UNAVAILABLE_INFO = 55;
+  BACKLOG_ANNOUNCEMENT = 59;
+  BUFFER_STREAM_END = 60;
+  READ_ADD_CREDIT = 61;
 }
 
 message PbStorageInfo {
@@ -505,8 +508,22 @@ message PbOpenStream {
 }
 
 message PbStreamHandler {
-  int64 streamId = 1 ;
+  int64 streamId = 1;
   int32 numChunks = 2;
-  repeated int64 chunkOffsets = 3 ;
+  repeated int64 chunkOffsets = 3;
   string fullPath = 4;
-}
\ No newline at end of file
+}
+
+message PbBacklogAnnouncement {
+  int64 streamId = 1;
+  int32 backlog = 2;
+}
+
+message PbBufferStreamEnd {
+  int64 streamId = 1;
+}
+
+message PbReadAddCredit {
+  int64 streamId = 1;
+  int32 credit = 2;
+}
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapDataPartitionReader.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapDataPartitionReader.java
index b8f996fe7..cecd47900 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapDataPartitionReader.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapDataPartitionReader.java
@@ -17,6 +17,8 @@
 
 package org.apache.celeborn.service.deploy.worker.storage;
 
+import static 
org.apache.celeborn.common.network.client.TransportClient.requestId;
+
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.ClosedChannelException;
@@ -36,11 +38,15 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.exception.FileCorruptedException;
 import org.apache.celeborn.common.meta.FileInfo;
+import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
-import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
 import org.apache.celeborn.common.network.protocol.ReadData;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.protocol.TransportableError;
 import org.apache.celeborn.common.network.util.NettyUtils;
+import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
 import org.apache.celeborn.common.util.ExceptionUtils;
 import org.apache.celeborn.common.util.Utils;
 import org.apache.celeborn.service.deploy.worker.memory.BufferQueue;
@@ -442,7 +448,17 @@ public class MapDataPartitionReader implements 
Comparable<MapDataPartitionReader
         // old client can't support BufferStreamEnd, so for new client it 
tells client that this
         // stream is finished.
         if (fileInfo.isPartitionSplitEnabled() && !errorNotified)
-          associatedChannel.writeAndFlush(new BufferStreamEnd(streamId));
+          associatedChannel.writeAndFlush(
+              new RpcRequest(
+                  requestId(),
+                  new NioManagedBuffer(
+                      new TransportMessage(
+                              MessageType.BUFFER_STREAM_END,
+                              PbBufferStreamEnd.newBuilder()
+                                  .setStreamId(streamId)
+                                  .build()
+                                  .toByteArray())
+                          .toByteBuffer())));
         if (!buffersToSend.isEmpty()) {
           numInUseBuffers.addAndGet(-1 * buffersToSend.size());
           buffersToSend.forEach(RecyclableBuffer::recycle);
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 019c8871e..f7e0e51a6 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -26,6 +26,7 @@ import java.util.function.Consumer
 import scala.collection.JavaConverters.asScalaBufferConverter
 
 import com.google.common.base.Throwables
+import com.google.protobuf.GeneratedMessageV3
 import io.netty.util.concurrent.{Future, GenericFutureListener}
 
 import org.apache.celeborn.common.CelebornConf
@@ -38,7 +39,7 @@ import org.apache.celeborn.common.network.protocol._
 import org.apache.celeborn.common.network.protocol.Message.Type
 import org.apache.celeborn.common.network.server.BaseMessageHandler
 import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
-import org.apache.celeborn.common.protocol.{MessageType, PartitionType, 
PbOpenStream, PbStreamHandler}
+import org.apache.celeborn.common.protocol.{MessageType, PartitionType, 
PbBufferStreamEnd, PbOpenStream, PbReadAddCredit, PbStreamHandler}
 import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
 import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, 
CreditStreamManager, PartitionFilesSorter, StorageManager}
 
@@ -90,72 +91,30 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
   override def receive(client: TransportClient, msg: RequestMessage): Unit = {
     msg match {
       case r: BufferStreamEnd =>
-        handleEndStreamFromClient(r)
+        handleEndStreamFromClient(r.getStreamId)
       case r: ReadAddCredit =>
-        handleReadAddCredit(r)
+        handleReadAddCredit(r.getCredit, r.getStreamId)
       case r: ChunkFetchRequest =>
         handleChunkFetchRequest(client, r)
       case r: RpcRequest =>
-        // process PbOpenStream RPC
         var streamShuffleKey: String = null
-        var streamFileName: String = null
         try {
-          val pbMsg = TransportMessage.fromByteBuffer(r.body().nioByteBuffer())
-          val pbOpenStream = pbMsg.getParsedPayload[PbOpenStream]
-          val (shuffleKey, fileName, startIndex, endIndex, initialCredit, 
readLocalShuffle) =
-            (
-              pbOpenStream.getShuffleKey,
-              pbOpenStream.getFileName,
-              pbOpenStream.getStartIndex,
-              pbOpenStream.getEndIndex,
-              pbOpenStream.getInitialCredit,
-              pbOpenStream.getReadLocalShuffle)
-          streamShuffleKey = shuffleKey
-          streamFileName = fileName
-          workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, 
streamShuffleKey)
-          handleOpenStreamInternal(
-            client,
-            shuffleKey,
-            fileName,
-            startIndex,
-            endIndex,
-            initialCredit,
-            r,
-            false,
-            readLocalShuffle)
-        } catch {
-          case _: Exception =>
-            // process legacy OpenStream RPCs
-            logDebug("Open stream with legacy RPCs")
-            try {
-              val decodedMsg = Message.decode(r.body().nioByteBuffer())
-              val (shuffleKey, fileName) =
-                if (decodedMsg.`type`() == Type.OPEN_STREAM) {
-                  val openStream = decodedMsg.asInstanceOf[OpenStream]
-                  (
-                    new String(openStream.shuffleKey, StandardCharsets.UTF_8),
-                    new String(openStream.fileName, StandardCharsets.UTF_8))
-                } else {
-                  val openStreamWithCredit = 
decodedMsg.asInstanceOf[OpenStreamWithCredit]
-                  (
-                    new String(openStreamWithCredit.shuffleKey, 
StandardCharsets.UTF_8),
-                    new String(openStreamWithCredit.fileName, 
StandardCharsets.UTF_8))
-                }
+          val pbMsg = TransportMessage.fromByteBuffer(
+            
r.body().nioByteBuffer()).getParsedPayload.asInstanceOf[GeneratedMessageV3]
+          pbMsg match {
+            case pb: PbBufferStreamEnd => 
handleEndStreamFromClient(pb.getStreamId)
+            case pb: PbReadAddCredit => handleReadAddCredit(pb.getCredit, 
pb.getStreamId)
+            case pb: PbOpenStream =>
+              val (shuffleKey, fileName, startIndex, endIndex, initialCredit, 
readLocalShuffle) =
+                (
+                  pb.getShuffleKey,
+                  pb.getFileName,
+                  pb.getStartIndex,
+                  pb.getEndIndex,
+                  pb.getInitialCredit,
+                  pb.getReadLocalShuffle)
               streamShuffleKey = shuffleKey
-              streamFileName = fileName
-              var startIndex = 0
-              var endIndex = 0
-              var initialCredit = 0
-              getRawFileInfo(shuffleKey, fileName).getPartitionType match {
-                case PartitionType.REDUCE =>
-                  startIndex = 
decodedMsg.asInstanceOf[OpenStream].startMapIndex
-                  endIndex = decodedMsg.asInstanceOf[OpenStream].endMapIndex
-                case PartitionType.MAP =>
-                  initialCredit = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].initialCredit
-                  startIndex = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].startIndex
-                  endIndex = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].endIndex
-                case PartitionType.MAPGROUP =>
-              }
+              workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, 
streamShuffleKey)
               handleOpenStreamInternal(
                 client,
                 shuffleKey,
@@ -164,14 +123,63 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
                 endIndex,
                 initialCredit,
                 r,
-                true)
-            } catch {
-              case e: IOException =>
-                handleRpcIOException(client, r.requestId, streamShuffleKey, 
streamFileName, e)
+                false,
+                readLocalShuffle)
+          }
+        } catch {
+          case _: Exception =>
+            logDebug("Legacy RPCs")
+            val decodedMsg = Message.decode(r.body().nioByteBuffer())
+            val msgType = decodedMsg.`type`()
+            if (msgType == Type.OPEN_STREAM || msgType == 
Type.OPEN_STREAM_WITH_CREDIT) {
+              var streamFileName: String = null
+              try {
+                val (shuffleKey, fileName) =
+                  if (msgType == Type.OPEN_STREAM) {
+                    val openStream = decodedMsg.asInstanceOf[OpenStream]
+                    (
+                      new String(openStream.shuffleKey, 
StandardCharsets.UTF_8),
+                      new String(openStream.fileName, StandardCharsets.UTF_8))
+                  } else {
+                    val openStreamWithCredit = 
decodedMsg.asInstanceOf[OpenStreamWithCredit]
+                    (
+                      new String(openStreamWithCredit.shuffleKey, 
StandardCharsets.UTF_8),
+                      new String(openStreamWithCredit.fileName, 
StandardCharsets.UTF_8))
+                  }
+                streamShuffleKey = shuffleKey
+                streamFileName = fileName
+                var startIndex = 0
+                var endIndex = 0
+                var initialCredit = 0
+                getRawFileInfo(shuffleKey, fileName).getPartitionType match {
+                  case PartitionType.REDUCE =>
+                    startIndex = 
decodedMsg.asInstanceOf[OpenStream].startMapIndex
+                    endIndex = decodedMsg.asInstanceOf[OpenStream].endMapIndex
+                  case PartitionType.MAP =>
+                    initialCredit = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].initialCredit
+                    startIndex = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].startIndex
+                    endIndex = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].endIndex
+                  case PartitionType.MAPGROUP =>
+                }
+                handleOpenStreamInternal(
+                  client,
+                  shuffleKey,
+                  fileName,
+                  startIndex,
+                  endIndex,
+                  initialCredit,
+                  r,
+                  true)
+              } catch {
+                case e: IOException =>
+                  handleRpcIOException(client, r.requestId, streamShuffleKey, 
streamFileName, e)
+              }
             }
         } finally {
           r.body().release()
-          workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, 
streamShuffleKey)
+          if (streamShuffleKey != null) {
+            workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, 
streamShuffleKey)
+          }
         }
       case unknown: RequestMessage =>
         throw new IllegalArgumentException(s"Unknown message type id: 
${unknown.`type`.id}")
@@ -296,17 +304,24 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
     logError(
       s"Read file: $fileName with shuffleKey: $shuffleKey error from 
${NettyUtils.getRemoteAddress(client.getChannel)}",
       ioe)
+    handleRpcException(client, requestId, ioe)
+  }
+
+  private def handleRpcException(
+      client: TransportClient,
+      requestId: Long,
+      ioe: IOException): Unit = {
     client.getChannel.writeAndFlush(new RpcFailure(
       requestId,
       
Throwables.getStackTraceAsString(ExceptionUtils.wrapIOExceptionToUnRetryable(ioe))))
   }
 
-  def handleEndStreamFromClient(req: BufferStreamEnd): Unit = {
-    creditStreamManager.notifyStreamEndByClient(req.getStreamId)
+  def handleEndStreamFromClient(streamId: Long): Unit = {
+    creditStreamManager.notifyStreamEndByClient(streamId)
   }
 
-  def handleReadAddCredit(req: ReadAddCredit): Unit = {
-    creditStreamManager.addCredit(req.getCredit, req.getStreamId)
+  def handleReadAddCredit(credit: Int, streamId: Long): Unit = {
+    creditStreamManager.addCredit(credit, streamId)
   }
 
   def handleChunkFetchRequest(client: TransportClient, req: 
ChunkFetchRequest): Unit = {

Reply via email to