This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new ea39a9372 [CELEBORN-760] Convert OpenStream and StreamHandler to Pb
ea39a9372 is described below

commit ea39a9372aec1995c7bc38dcadc38a41127bc50b
Author: mingji <[email protected]>
AuthorDate: Sat Aug 5 13:58:08 2023 +0800

    [CELEBORN-760] Convert OpenStream and StreamHandler to Pb
    
    ### What changes were proposed in this pull request?
    Migrate the OpenStream and StreamHandle RPC messages to protobuf-based
    TransportMessages (PbOpenStream and PbStreamHandler) to improve Celeborn's compatibility.
    
    ### Why are the changes needed?
    1. Make it easier to evolve the RPC messages.
    2. Stay compatible with 0.2 clients (the worker still accepts legacy OpenStream requests).
    
    ### Does this PR introduce _any_ user-facing change?
    NO.
    
    ### How was this patch tested?
    Unit tests and cluster tests.
    
    Closes #1750 from FMX/CELEBORN-760.
    
    Lead-authored-by: mingji <[email protected]>
    Co-authored-by: Ethan Feng <[email protected]>
    Co-authored-by: Keyong Zhou <[email protected]>
    Co-authored-by: Keyong Zhou <[email protected]>
    Co-authored-by: Keyong Zhou <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../celeborn/client/read/DfsPartitionReader.java   |  21 ++-
 .../client/read/WorkerPartitionReader.java         |  36 ++--
 .../celeborn/common/network/protocol/Message.java  |   1 +
 .../common/network/protocol/OpenStream.java        |   6 +-
 .../network/protocol/OpenStreamWithCredit.java     |   1 +
 .../common/network/protocol/StreamHandle.java      |   3 +-
 .../common/network/protocol/TransportMessage.java  |  47 ++++++
 common/src/main/proto/TransportMessages.proto      |  18 ++
 .../celeborn/tests/flink/HeartbeatTest.scala       |   1 -
 .../service/deploy/worker/FetchHandler.scala       | 183 ++++++++++++++-------
 .../deploy/worker/storage/FileWriterSuiteJ.java    |  35 ++--
 11 files changed, 253 insertions(+), 99 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java 
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
index 316f11ef1..637aada7f 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
@@ -37,9 +37,10 @@ import org.apache.celeborn.client.ShuffleClient;
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.Message;
-import org.apache.celeborn.common.network.protocol.OpenStream;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbOpenStream;
 import org.apache.celeborn.common.util.ShuffleBlockInfoUtils;
 import org.apache.celeborn.common.util.Utils;
 
@@ -77,10 +78,18 @@ public class DfsPartitionReader implements PartitionReader {
       try {
         TransportClient client =
             clientFactory.createClient(location.getHost(), 
location.getFetchPort());
-        OpenStream openBlocks =
-            new OpenStream(shuffleKey, location.getFileName(), startMapIndex, 
endMapIndex);
-        ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), 
fetchTimeoutMs);
-        Message.decode(response);
+        TransportMessage openStream =
+            new TransportMessage(
+                MessageType.OPEN_STREAM,
+                PbOpenStream.newBuilder()
+                    .setShuffleKey(shuffleKey)
+                    .setFileName(location.getFileName())
+                    .setStartIndex(startMapIndex)
+                    .setEndIndex(endMapIndex)
+                    .build()
+                    .toByteArray());
+        ByteBuffer response = client.sendRpcSync(openStream.toByteBuffer(), 
fetchTimeoutMs);
+        TransportMessage.fromByteBuffer(response).getParsedPayload();
         // Parse this message to ensure sort is done.
       } catch (IOException | InterruptedException e) {
         throw new IOException(
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
 
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
index e323fe6e1..c02c7dd4f 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
@@ -35,17 +35,18 @@ import 
org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
 import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.Message;
-import org.apache.celeborn.common.network.protocol.OpenStream;
-import org.apache.celeborn.common.network.protocol.StreamHandle;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
 import org.apache.celeborn.common.util.ExceptionUtils;
 
 public class WorkerPartitionReader implements PartitionReader {
   private final Logger logger = 
LoggerFactory.getLogger(WorkerPartitionReader.class);
   private PartitionLocation location;
   private final TransportClientFactory clientFactory;
-  private StreamHandle streamHandle;
+  private PbStreamHandler streamHandle;
 
   private int returnedChunks;
   private int chunkIndex;
@@ -105,10 +106,19 @@ public class WorkerPartitionReader implements 
PartitionReader {
       logger.error("PartitionReader thread interrupted while creating 
client.");
       throw ie;
     }
-    OpenStream openBlocks =
-        new OpenStream(shuffleKey, location.getFileName(), startMapIndex, 
endMapIndex);
-    ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), 
fetchTimeoutMs);
-    streamHandle = (StreamHandle) Message.decode(response);
+
+    TransportMessage openStreamMsg =
+        new TransportMessage(
+            MessageType.OPEN_STREAM,
+            PbOpenStream.newBuilder()
+                .setShuffleKey(shuffleKey)
+                .setFileName(location.getFileName())
+                .setStartIndex(startMapIndex)
+                .setEndIndex(endMapIndex)
+                .build()
+                .toByteArray());
+    ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), 
fetchTimeoutMs);
+    streamHandle = 
TransportMessage.fromByteBuffer(response).getParsedPayload();
 
     this.location = location;
     this.clientFactory = clientFactory;
@@ -118,12 +128,12 @@ public class WorkerPartitionReader implements 
PartitionReader {
   }
 
   public boolean hasNext() {
-    return returnedChunks < streamHandle.numChunks;
+    return returnedChunks < streamHandle.getNumChunks();
   }
 
   public ByteBuf next() throws IOException, InterruptedException {
     checkException();
-    if (chunkIndex < streamHandle.numChunks) {
+    if (chunkIndex < streamHandle.getNumChunks()) {
       fetchChunks();
     }
     ByteBuf chunk = null;
@@ -159,7 +169,7 @@ public class WorkerPartitionReader implements 
PartitionReader {
     final int inFlight = chunkIndex - returnedChunks;
     if (inFlight < fetchMaxReqsInFlight) {
       final int toFetch =
-          Math.min(fetchMaxReqsInFlight - inFlight + 1, streamHandle.numChunks 
- chunkIndex);
+          Math.min(fetchMaxReqsInFlight - inFlight + 1, 
streamHandle.getNumChunks() - chunkIndex);
       for (int i = 0; i < toFetch; i++) {
         if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && 
chunkIndex == 3) {
           callback.onFailure(chunkIndex, new CelebornIOException("Test fetch 
chunk failure"));
@@ -167,12 +177,12 @@ public class WorkerPartitionReader implements 
PartitionReader {
           try {
             TransportClient client =
                 clientFactory.createClient(location.getHost(), 
location.getFetchPort());
-            client.fetchChunk(streamHandle.streamId, chunkIndex, 
fetchTimeoutMs, callback);
+            client.fetchChunk(streamHandle.getStreamId(), chunkIndex, 
fetchTimeoutMs, callback);
             chunkIndex++;
           } catch (IOException e) {
             logger.error(
                 "fetchChunk for streamId: {}, chunkIndex: {} failed.",
-                streamHandle.streamId,
+                streamHandle.getStreamId(),
                 chunkIndex,
                 e);
             ExceptionUtils.wrapAndThrowIOException(e);
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
index 989c40e50..b2aaf7354 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
@@ -230,6 +230,7 @@ public abstract class Message implements Encodable {
 
       case TRANSPORTABLE_ERROR:
         return TransportableError.decode(in);
+
       case BUFFER_STREAM_END:
         return BufferStreamEnd.decode(in);
 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStream.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStream.java
index ff1ae1565..da5b3dad3 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStream.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStream.java
@@ -23,7 +23,11 @@ import java.util.Arrays;
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
-/** Request to read a set of blocks. Returns {@link StreamHandle}. */
+/**
+ * Request to read a set of blocks. Returns {@link StreamHandle}. Use 
PbOpenStream instead of this
+ * one.
+ */
+@Deprecated
 public final class OpenStream extends RequestMessage {
   public byte[] shuffleKey;
   public byte[] fileName;
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java
index d59d86dfc..ea3f12a6b 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java
@@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets;
 import io.netty.buffer.ByteBuf;
 
 /** Buffer stream used in Map partition scenario. */
+@Deprecated
 public final class OpenStreamWithCredit extends RequestMessage {
   public final byte[] shuffleKey;
   public final byte[] fileName;
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamHandle.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamHandle.java
index 4d50bec66..c2428728e 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamHandle.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamHandle.java
@@ -22,8 +22,9 @@ import io.netty.buffer.ByteBuf;
 
 /**
  * Identifier for a fixed number of chunks to read from a stream created by an 
"open blocks"
- * message.
+ * message. Use PbStreamHandler instead of this.
  */
+@Deprecated
 public final class StreamHandle extends RequestMessage {
   public final long streamId;
   public final int numChunks;
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index a139fc08e..d72bfec1b 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -17,12 +17,25 @@
 
 package org.apache.celeborn.common.network.protocol;
 
+import static 
org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
+import static 
org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;
+
 import java.io.Serializable;
+import java.nio.ByteBuffer;
+
+import com.google.protobuf.GeneratedMessageV3;
+import com.google.protobuf.InvalidProtocolBufferException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
+import org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
 
 public class TransportMessage implements Serializable {
   private static final long serialVersionUID = -3259000920699629773L;
+  private static Logger logger = 
LoggerFactory.getLogger(TransportMessage.class);
   @Deprecated private final MessageType type;
   private final int messageTypeValue;
   private final byte[] payload;
@@ -44,4 +57,38 @@ public class TransportMessage implements Serializable {
   public byte[] getPayload() {
     return payload;
   }
+
+  public <T extends GeneratedMessageV3> T getParsedPayload() throws 
InvalidProtocolBufferException {
+    switch (messageTypeValue) {
+      case OPEN_STREAM_VALUE:
+        return (T) PbOpenStream.parseFrom(payload);
+      case STREAM_HANDLER_VALUE:
+        return (T) PbStreamHandler.parseFrom(payload);
+      default:
+        logger.error("Unexpected type {}", type);
+    }
+    return null;
+  }
+
+  public ByteBuffer toByteBuffer() {
+    int totalBufferSize = payload.length + 4 + 4;
+    ByteBuffer buffer = ByteBuffer.allocate(totalBufferSize);
+    buffer.putInt(messageTypeValue);
+    buffer.putInt(payload.length);
+    buffer.put(payload);
+    buffer.flip();
+    return buffer;
+  }
+
+  public static TransportMessage fromByteBuffer(ByteBuffer buffer) throws 
CelebornIOException {
+    int messageTypeValue = buffer.getInt();
+    if (MessageType.forNumber(messageTypeValue) == null) {
+      throw new CelebornIOException("Decode failed, fallback to legacy 
messages.");
+    }
+    int payloadLen = buffer.getInt();
+    byte[] payload = new byte[payloadLen];
+    buffer.get(payload);
+    MessageType msgType = MessageType.forNumber(messageTypeValue);
+    return new TransportMessage(msgType, payload);
+  }
 }
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 0454ceff1..172dc1a91 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -69,6 +69,8 @@ enum MessageType {
   REGISTER_MAP_PARTITION_TASK = 48;
   HEARTBEAT_FROM_APPLICATION_RESPONSE = 49;
   CHECK_FOR_HDFS_EXPIRED_DIRS_TIMEOUT = 50;
+  OPEN_STREAM = 51;
+  STREAM_HANDLER = 52;
 }
 
 message PbStorageInfo {
@@ -446,3 +448,19 @@ message PbSnapshotMetaInfo {
   map<string, int64> lostWorkers = 12;
   repeated PbWorkerInfo shutdownWorkers = 13;
 }
+
+message PbOpenStream {
+  string shuffleKey = 1;
+  string fileName = 2;
+  int32 startIndex = 3;
+  int32 endIndex = 4;
+  int32 initialCredit = 5;
+  bool localRead = 6;
+}
+
+message PbStreamHandler {
+  int64 streamId = 1 ;
+  int32 numChunks = 2;
+  repeated int64 chunkOffsets = 3 ;
+  string fullPath = 4;
+}
\ No newline at end of file
diff --git 
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
 
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
index 625debb4c..a373c6fd8 100644
--- 
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
+++ 
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
@@ -24,7 +24,6 @@ import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl
 import org.apache.celeborn.service.deploy.{HeartbeatFeature, 
MiniClusterFeature}
-import org.apache.celeborn.service.deploy.worker.memory.MemoryManager;
 
 class HeartbeatTest extends AnyFunSuite with Logging with MiniClusterFeature 
with HeartbeatFeature
   with BeforeAndAfterAll with BeforeAndAfterEach {
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 5426bbb77..45e700695 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -17,9 +17,9 @@
 
 package org.apache.celeborn.service.deploy.worker
 
+import java.{lang, util}
 import java.io.{FileNotFoundException, IOException}
 import java.nio.charset.StandardCharsets
-import java.util
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.function.Consumer
 
@@ -36,7 +36,7 @@ import org.apache.celeborn.common.network.protocol._
 import org.apache.celeborn.common.network.protocol.Message.Type
 import org.apache.celeborn.common.network.server.BaseMessageHandler
 import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
-import org.apache.celeborn.common.protocol.PartitionType
+import org.apache.celeborn.common.protocol.{MessageType, PartitionType, 
PbOpenStream, PbStreamHandler}
 import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
 import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, 
CreditStreamManager, PartitionFilesSorter, StorageManager}
 
@@ -94,51 +94,110 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
       case r: ChunkFetchRequest =>
         handleChunkFetchRequest(client, r)
       case r: RpcRequest =>
-        val msg = Message.decode(r.body().nioByteBuffer())
-        handleOpenStream(client, r, msg)
+        // process PbOpenStream RPC
+        var timerShuffleKey: String = null
+        try {
+          val pbMsg = TransportMessage.fromByteBuffer(r.body().nioByteBuffer())
+          val pbOpenStream = pbMsg.getParsedPayload[PbOpenStream]
+          val (shuffleKey, fileName, startIndex, endIndex, initialCredit) =
+            (
+              pbOpenStream.getShuffleKey,
+              pbOpenStream.getFileName,
+              pbOpenStream.getStartIndex,
+              pbOpenStream.getEndIndex,
+              pbOpenStream.getInitialCredit)
+
+          timerShuffleKey = shuffleKey
+          workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, 
timerShuffleKey)
+          handleOpenStreamInternal(
+            client,
+            shuffleKey,
+            fileName,
+            startIndex,
+            endIndex,
+            initialCredit,
+            r,
+            false)
+        } catch {
+          case _: Exception =>
+            // process legacy OpenStream RPCs
+            logDebug("Open stream with legacy RPCs")
+            try {
+              val decodedMsg = Message.decode(r.body().nioByteBuffer())
+              val (shuffleKey, fileName) =
+                if (decodedMsg.`type`() == Type.OPEN_STREAM) {
+                  val openStream = decodedMsg.asInstanceOf[OpenStream]
+                  (
+                    new String(openStream.shuffleKey, StandardCharsets.UTF_8),
+                    new String(openStream.fileName, StandardCharsets.UTF_8))
+                } else {
+                  val openStreamWithCredit = 
decodedMsg.asInstanceOf[OpenStreamWithCredit]
+                  (
+                    new String(openStreamWithCredit.shuffleKey, 
StandardCharsets.UTF_8),
+                    new String(openStreamWithCredit.fileName, 
StandardCharsets.UTF_8))
+                }
+              timerShuffleKey = shuffleKey
+              var startIndex = 0
+              var endIndex = 0
+              var initialCredit = 0
+              getRawFileInfo(shuffleKey, fileName).getPartitionType match {
+                case PartitionType.REDUCE =>
+                  startIndex = 
decodedMsg.asInstanceOf[OpenStream].startMapIndex
+                  endIndex = decodedMsg.asInstanceOf[OpenStream].endMapIndex
+                case PartitionType.MAP =>
+                  initialCredit = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].initialCredit
+                  startIndex = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].startIndex
+                  endIndex = 
decodedMsg.asInstanceOf[OpenStreamWithCredit].endIndex
+                case PartitionType.MAPGROUP =>
+              }
+              handleOpenStreamInternal(
+                client,
+                shuffleKey,
+                fileName,
+                startIndex,
+                endIndex,
+                initialCredit,
+                r,
+                true)
+            } catch {
+              case e: IOException =>
+                handleRpcIOException(client, r.requestId, e)
+            }
+        } finally {
+          r.body().release()
+          workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, 
timerShuffleKey)
+        }
       case unknown: RequestMessage =>
         throw new IllegalArgumentException(s"Unknown message type id: 
${unknown.`type`.id}")
     }
   }
 
-  // here are BackLogAnnouncement,OpenStream and OpenStreamWithCredit RPCs to 
handle
-  def handleOpenStream(client: TransportClient, request: RpcRequest, msg: 
Message): Unit = {
-    val (shuffleKey, fileName) =
-      if (msg.`type`() == Type.OPEN_STREAM) {
-        val openStream = msg.asInstanceOf[OpenStream]
-        (
-          new String(openStream.shuffleKey, StandardCharsets.UTF_8),
-          new String(openStream.fileName, StandardCharsets.UTF_8))
-      } else {
-        val openStreamWithCredit = msg.asInstanceOf[OpenStreamWithCredit]
-        (
-          new String(openStreamWithCredit.shuffleKey, StandardCharsets.UTF_8),
-          new String(openStreamWithCredit.fileName, StandardCharsets.UTF_8))
-      }
-    // metrics start
-    workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
+  private def handleOpenStreamInternal(
+      client: TransportClient,
+      shuffleKey: String,
+      fileName: String,
+      startIndex: Int,
+      endIndex: Int,
+      initialCredit: Int,
+      request: RpcRequest,
+      isLegacy: Boolean): Unit = {
     try {
       var fileInfo = getRawFileInfo(shuffleKey, fileName)
-      try fileInfo.getPartitionType match {
+      fileInfo.getPartitionType match {
         case PartitionType.REDUCE =>
-          val startMapIndex = msg.asInstanceOf[OpenStream].startMapIndex
-          val endMapIndex = msg.asInstanceOf[OpenStream].endMapIndex
-          if (endMapIndex != Integer.MAX_VALUE) {
+          if (endIndex != Integer.MAX_VALUE) {
             fileInfo = partitionsSorter.getSortedFileInfo(
               shuffleKey,
               fileName,
               fileInfo,
-              startMapIndex,
-              endMapIndex)
+              startIndex,
+              endIndex)
           }
-          logDebug(s"Received chunk fetch request $shuffleKey $fileName 
$startMapIndex " +
-            s"$endMapIndex get file info $fileInfo from client channel " +
+          logDebug(s"Received chunk fetch request $shuffleKey $fileName 
$startIndex " +
+            s"$endIndex get file info $fileInfo from client channel " +
             s"${NettyUtils.getRemoteAddress(client.getChannel)}")
           if (fileInfo.isHdfs) {
-            val streamHandle = new StreamHandle(0, 0)
-            client.getChannel.writeAndFlush(new RpcResponse(
-              request.requestId,
-              new NioManagedBuffer(streamHandle.toByteBuffer)))
+            replyStreamHandler(client, request.requestId, 0, 0, isLegacy)
           } else {
             val buffers = new FileManagedBuffers(fileInfo, transportConf)
             val fetchTimeMetrics = 
storageManager.getFetchTimeMetric(fileInfo.getFile)
@@ -146,54 +205,56 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
               shuffleKey,
               buffers,
               fetchTimeMetrics)
-            val streamHandle = new StreamHandle(streamId, fileInfo.numChunks())
             if (fileInfo.numChunks() == 0)
               logDebug(s"StreamId $streamId, fileName $fileName, mapRange " +
-                s"[$startMapIndex-$endMapIndex] is empty. Received from client 
channel " +
+                s"[$startIndex-$endIndex] is empty. Received from client 
channel " +
                 s"${NettyUtils.getRemoteAddress(client.getChannel)}")
             else logDebug(
               s"StreamId $streamId, fileName $fileName, numChunks 
${fileInfo.numChunks()}, " +
-                s"mapRange [$startMapIndex-$endMapIndex]. Received from client 
channel " +
+                s"mapRange [$startIndex-$endIndex]. Received from client 
channel " +
                 s"${NettyUtils.getRemoteAddress(client.getChannel)}")
-            client.getChannel.writeAndFlush(new RpcResponse(
-              request.requestId,
-              new NioManagedBuffer(streamHandle.toByteBuffer)))
+            replyStreamHandler(client, request.requestId, streamId, 
fileInfo.numChunks(), isLegacy)
           }
         case PartitionType.MAP =>
-          val initialCredit = 
msg.asInstanceOf[OpenStreamWithCredit].initialCredit
-          val startIndex = msg.asInstanceOf[OpenStreamWithCredit].startIndex
-          val endIndex = msg.asInstanceOf[OpenStreamWithCredit].endIndex
-
-          val callback = new Consumer[java.lang.Long] {
-            override def accept(streamId: java.lang.Long): Unit = {
-              val bufferStreamHandle = new StreamHandle(streamId, 0)
-              client.getChannel.writeAndFlush(new RpcResponse(
-                request.requestId,
-                new NioManagedBuffer(bufferStreamHandle.toByteBuffer)))
+          val creditStreamHandler =
+            new Consumer[java.lang.Long] {
+              override def accept(streamId: java.lang.Long): Unit = {
+                replyStreamHandler(client, request.requestId, streamId, 0, 
isLegacy)
+              }
             }
-          }
 
           creditStreamManager.registerStream(
-            callback,
+            creditStreamHandler,
             client.getChannel,
             initialCredit,
             startIndex,
             endIndex,
             fileInfo)
-
         case PartitionType.MAPGROUP =>
-      } catch {
-        case e: IOException =>
-          handleRpcIOException(client, request.requestId, e)
-      } finally {
-        // metrics end
-        workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
-        request.body().release()
       }
     } catch {
-      case ioe: IOException =>
-        workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
-        handleRpcIOException(client, request.requestId, ioe)
+      case e: IOException =>
+        handleRpcIOException(client, request.requestId, e)
+    }
+  }
+
+  private def replyStreamHandler(
+      client: TransportClient,
+      requestId: Long,
+      streamId: Long,
+      numChunks: Int,
+      isLegacy: Boolean): Unit = {
+    if (isLegacy) {
+      client.getChannel.writeAndFlush(new RpcResponse(
+        requestId,
+        new NioManagedBuffer(new StreamHandle(streamId, 
numChunks).toByteBuffer)))
+    } else {
+      client.getChannel.writeAndFlush(new RpcResponse(
+        requestId,
+        new NioManagedBuffer(new TransportMessage(
+          MessageType.STREAM_HANDLER,
+          PbStreamHandler.newBuilder.setStreamId(streamId).setNumChunks(
+            numChunks).build.toByteArray).toByteBuffer)))
     }
   }
 
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
index 81decac0c..ac13df7c7 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
@@ -55,6 +55,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornException;
 import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.meta.FileInfo;
 import org.apache.celeborn.common.network.TransportContext;
@@ -62,15 +63,11 @@ import 
org.apache.celeborn.common.network.buffer.ManagedBuffer;
 import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.Message;
-import org.apache.celeborn.common.network.protocol.OpenStream;
-import org.apache.celeborn.common.network.protocol.StreamHandle;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.server.TransportServer;
 import org.apache.celeborn.common.network.util.NettyUtils;
 import org.apache.celeborn.common.network.util.TransportConf;
-import org.apache.celeborn.common.protocol.PartitionSplitMode;
-import org.apache.celeborn.common.protocol.PartitionType;
-import org.apache.celeborn.common.protocol.StorageInfo;
+import org.apache.celeborn.common.protocol.*;
 import org.apache.celeborn.common.util.JavaUtils;
 import org.apache.celeborn.common.util.ThreadUtils;
 import org.apache.celeborn.common.util.Utils;
@@ -198,19 +195,25 @@ public class FileWriterSuiteJ {
   }
 
   public ByteBuffer createOpenMessage() {
-    byte[] shuffleKeyBytes = "shuffleKey".getBytes(StandardCharsets.UTF_8);
-    byte[] fileNameBytes = "location".getBytes(StandardCharsets.UTF_8);
-
-    OpenStream openBlocks = new OpenStream(shuffleKeyBytes, fileNameBytes, 0, 
Integer.MAX_VALUE);
-
-    return openBlocks.toByteBuffer();
+    TransportMessage message =
+        new TransportMessage(
+            MessageType.OPEN_STREAM,
+            PbOpenStream.newBuilder()
+                .setShuffleKey("shuffleKey")
+                .setFileName("location")
+                .setStartIndex(0)
+                .setEndIndex(Integer.MAX_VALUE)
+                .build()
+                .toByteArray());
+
+    return message.toByteBuffer();
   }
 
-  private void setUpConn(TransportClient client) throws IOException {
+  private void setUpConn(TransportClient client) throws IOException, 
CelebornException {
     ByteBuffer resp = client.sendRpcSync(createOpenMessage(), 10000);
-    StreamHandle streamHandle = (StreamHandle) Message.decode(resp);
-    streamId = streamHandle.streamId;
-    numChunks = streamHandle.numChunks;
+    PbStreamHandler streamHandle = 
TransportMessage.fromByteBuffer(resp).getParsedPayload();
+    streamId = streamHandle.getStreamId();
+    numChunks = streamHandle.getNumChunks();
   }
 
   private FetchResult fetchChunks(TransportClient client, List<Integer> 
chunkIndices)

Reply via email to