This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 55e85055b [CELEBORN-771][FLINK] Convert PushDataHandShake, 
RegionFinish, RegionStart to PB
55e85055b is described below

commit 55e85055b80111b6f3662474c8fd500bf6589f2a
Author: SteNicholas <[email protected]>
AuthorDate: Fri Sep 22 11:36:45 2023 +0800

    [CELEBORN-771][FLINK] Convert PushDataHandShake, RegionFinish, RegionStart 
to PB
    
    ### What changes were proposed in this pull request?
    
    `PushDataHandShake`, `RegionFinish`, and `RegionStart` should be merged into 
transport messages to enhance Celeborn's compatibility.
    
    ### Why are the changes needed?
    
    1. Improves the flexibility of Celeborn's transport layer when the RPC protocol changes.
    2. Keeps compatibility with the 0.2 client.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    - `RemoteShuffleOutputGateSuiteJ`
    
    Closes #1910 from SteNicholas/CELEBORN-771.
    
    Authored-by: SteNicholas <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 .../flink/readclient/FlinkShuffleClientImpl.java   |  70 ++++++----
 .../common/network/protocol/PushDataHandShake.java |   1 +
 .../common/network/protocol/RegionFinish.java      |   1 +
 .../common/network/protocol/RegionStart.java       |   1 +
 .../common/network/protocol/TransportMessage.java  |  12 ++
 common/src/main/proto/TransportMessages.proto      |  30 ++++-
 .../service/deploy/worker/PushDataHandler.scala    | 145 +++++++++++++++++----
 7 files changed, 207 insertions(+), 53 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index 992716894..c7ed33b33 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -44,13 +44,16 @@ import 
org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.network.protocol.PushData;
-import org.apache.celeborn.common.network.protocol.PushDataHandShake;
-import org.apache.celeborn.common.network.protocol.RegionFinish;
-import org.apache.celeborn.common.network.protocol.RegionStart;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PartitionLocation;
 import org.apache.celeborn.common.protocol.PbChangeLocationPartitionInfo;
 import org.apache.celeborn.common.protocol.PbChangeLocationResponse;
+import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode;
+import org.apache.celeborn.common.protocol.PbPushDataHandShake;
+import org.apache.celeborn.common.protocol.PbRegionFinish;
+import org.apache.celeborn.common.protocol.PbRegionStart;
 import org.apache.celeborn.common.protocol.ReviveRequest;
 import org.apache.celeborn.common.protocol.TransportModuleConstants;
 import org.apache.celeborn.common.protocol.message.ControlMessages;
@@ -332,18 +335,23 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
               location.getUniqueId());
           logger.debug("PushDataHandShake location {}", location);
           TransportClient client = 
createClientWaitingInFlightRequest(location, mapKey, pushState);
-          PushDataHandShake handShake =
-              new PushDataHandShake(
-                  PRIMARY_MODE,
-                  shuffleKey,
-                  location.getUniqueId(),
-                  attemptId,
-                  numPartitions,
-                  bufferSize);
           ByteBuffer pushDataHandShakeResponse;
           try {
             pushDataHandShakeResponse =
-                client.sendRpcSync(handShake.toByteBuffer(), 
conf.pushDataTimeoutMs());
+                client.sendRpcSync(
+                    new TransportMessage(
+                            MessageType.PUSH_DATA_HAND_SHAKE,
+                            PbPushDataHandShake.newBuilder()
+                                .setMode(Mode.forNumber(PRIMARY_MODE))
+                                .setShuffleKey(shuffleKey)
+                                .setPartitionUniqueId(location.getUniqueId())
+                                .setAttemptId(attemptId)
+                                .setNumPartitions(numPartitions)
+                                .setBufferSize(bufferSize)
+                                .build()
+                                .toByteArray())
+                        .toByteBuffer(),
+                    conf.pushDataTimeoutMs());
           } catch (IOException e) {
             // ioexeption revive
             return revive(shuffleId, mapId, attemptId, location);
@@ -378,18 +386,23 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
               location.getUniqueId());
           logger.debug("RegionStart  for location {}.", location.toString());
           TransportClient client = 
createClientWaitingInFlightRequest(location, mapKey, pushState);
-          RegionStart regionStart =
-              new RegionStart(
-                  PRIMARY_MODE,
-                  shuffleKey,
-                  location.getUniqueId(),
-                  attemptId,
-                  currentRegionIdx,
-                  isBroadcast);
           ByteBuffer regionStartResponse;
           try {
             regionStartResponse =
-                client.sendRpcSync(regionStart.toByteBuffer(), 
conf.pushDataTimeoutMs());
+                client.sendRpcSync(
+                    new TransportMessage(
+                            MessageType.REGION_START,
+                            PbRegionStart.newBuilder()
+                                .setMode(Mode.forNumber(PRIMARY_MODE))
+                                .setShuffleKey(shuffleKey)
+                                .setPartitionUniqueId(location.getUniqueId())
+                                .setAttemptId(attemptId)
+                                .setCurrentRegionIndex(currentRegionIdx)
+                                .setIsBroadcast(isBroadcast)
+                                .build()
+                                .toByteArray())
+                        .toByteBuffer(),
+                    conf.pushDataTimeoutMs());
           } catch (IOException e) {
             // ioexeption revive
             return revive(shuffleId, mapId, attemptId, location);
@@ -459,9 +472,18 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
               location.getUniqueId());
           logger.debug("RegionFinish for location {}.", location);
           TransportClient client = 
createClientWaitingInFlightRequest(location, mapKey, pushState);
-          RegionFinish regionFinish =
-              new RegionFinish(PRIMARY_MODE, shuffleKey, 
location.getUniqueId(), attemptId);
-          client.sendRpcSync(regionFinish.toByteBuffer(), 
conf.pushDataTimeoutMs());
+          client.sendRpcSync(
+              new TransportMessage(
+                      MessageType.REGION_FINISH,
+                      PbRegionFinish.newBuilder()
+                          .setMode(Mode.forNumber(PRIMARY_MODE))
+                          .setShuffleKey(shuffleKey)
+                          .setPartitionUniqueId(location.getUniqueId())
+                          .setAttemptId(attemptId)
+                          .build()
+                          .toByteArray())
+                  .toByteBuffer(),
+              conf.pushDataTimeoutMs());
           return null;
         });
   }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java
index 163fcaeb9..dc8c04816 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java
@@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol;
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
+@Deprecated
 public final class PushDataHandShake extends RequestMessage {
   // 0 for primary, 1 for replica, see PartitionLocation.Mode
   public final byte mode;
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java
index 62fe18e53..c7f804d84 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java
@@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol;
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
+@Deprecated
 public final class RegionFinish extends RequestMessage {
 
   // 0 for primary, 1 for replica, see PartitionLocation.Mode
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java
index 322029d28..4081c5cf4 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java
@@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol;
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
+@Deprecated
 public final class RegionStart extends RequestMessage {
 
   // 0 for primary, 1 for replica, see PartitionLocation.Mode
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index d72bfec1b..d59bcaab5 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -18,6 +18,9 @@
 package org.apache.celeborn.common.network.protocol;
 
 import static 
org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
+import static 
org.apache.celeborn.common.protocol.MessageType.PUSH_DATA_HAND_SHAKE_VALUE;
+import static 
org.apache.celeborn.common.protocol.MessageType.REGION_FINISH_VALUE;
+import static 
org.apache.celeborn.common.protocol.MessageType.REGION_START_VALUE;
 import static 
org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;
 
 import java.io.Serializable;
@@ -31,6 +34,9 @@ import org.slf4j.LoggerFactory;
 import org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbPushDataHandShake;
+import org.apache.celeborn.common.protocol.PbRegionFinish;
+import org.apache.celeborn.common.protocol.PbRegionStart;
 import org.apache.celeborn.common.protocol.PbStreamHandler;
 
 public class TransportMessage implements Serializable {
@@ -64,6 +70,12 @@ public class TransportMessage implements Serializable {
         return (T) PbOpenStream.parseFrom(payload);
       case STREAM_HANDLER_VALUE:
         return (T) PbStreamHandler.parseFrom(payload);
+      case PUSH_DATA_HAND_SHAKE_VALUE:
+        return (T) PbPushDataHandShake.parseFrom(payload);
+      case REGION_START_VALUE:
+        return (T) PbRegionStart.parseFrom(payload);
+      case REGION_FINISH_VALUE:
+        return (T) PbRegionFinish.parseFrom(payload);
       default:
         logger.error("Unexpected type {}", type);
     }
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 2e55c73b1..67db3e4c1 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -76,6 +76,9 @@ enum MessageType {
   CHECK_WORKERS_AVAILABLE = 53;
   CHECK_WORKERS_AVAILABLE_RESPONSE = 54;
   REMOVE_WORKERS_UNAVAILABLE_INFO = 55;
+  PUSH_DATA_HAND_SHAKE = 56;
+  REGION_START = 57;
+  REGION_FINISH = 58;
 }
 
 message PbStorageInfo {
@@ -499,4 +502,29 @@ message PbStreamHandler {
   int32 numChunks = 2;
   repeated int64 chunkOffsets = 3 ;
   string fullPath = 4;
-}
\ No newline at end of file
+}
+
+message PbPushDataHandShake {
+  PbPartitionLocation.Mode mode = 1;
+  string shuffleKey = 2;
+  string partitionUniqueId = 3;
+  int32 attemptId = 4;
+  int32 numPartitions = 5;
+  int32 bufferSize = 6;
+}
+
+message PbRegionStart {
+  PbPartitionLocation.Mode mode = 1;
+  string shuffleKey = 2;
+  string partitionUniqueId = 3;
+  int32 attemptId = 4;
+  int32 currentRegionIndex = 5;
+  bool isBroadcast = 6;
+}
+
+message PbRegionFinish {
+  PbPartitionLocation.Mode mode = 1;
+  string shuffleKey = 2;
+  string partitionUniqueId = 3;
+  int32 attemptId = 4;
+}
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 68facb601..0a201439a 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, 
ThreadPoolExecutor}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray}
 
 import com.google.common.base.Throwables
+import com.google.protobuf.GeneratedMessageV3
 import io.netty.buffer.ByteBuf
 
 import org.apache.celeborn.common.exception.{AlreadyClosedException, 
CelebornIOException}
@@ -30,10 +31,11 @@ import org.apache.celeborn.common.meta.{DiskStatus, 
WorkerInfo, WorkerPartitionL
 import org.apache.celeborn.common.metrics.source.Source
 import org.apache.celeborn.common.network.buffer.{NettyManagedBuffer, 
NioManagedBuffer}
 import org.apache.celeborn.common.network.client.{RpcResponseCallback, 
TransportClient, TransportClientFactory}
-import org.apache.celeborn.common.network.protocol.{Message, PushData, 
PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage, 
RpcFailure, RpcRequest, RpcResponse}
+import org.apache.celeborn.common.network.protocol.{Message, PushData, 
PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage, 
RpcFailure, RpcRequest, RpcResponse, TransportMessage}
 import org.apache.celeborn.common.network.protocol.Message.Type
 import org.apache.celeborn.common.network.server.BaseMessageHandler
-import org.apache.celeborn.common.protocol.{PartitionLocation, 
PartitionSplitMode, PartitionType}
+import org.apache.celeborn.common.protocol.{PartitionLocation, 
PartitionSplitMode, PartitionType, PbPushDataHandShake, PbRegionFinish, 
PbRegionStart}
+import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.unsafe.Platform
 import org.apache.celeborn.common.util.Utils
@@ -798,10 +800,8 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
         pushData.`type`(),
         shuffleKey,
         pushData.partitionUniqueId,
-        null,
         location,
-        callback,
-        wrappedCallback)) return
+        callback)) return
 
     val fileWriter =
       getFileWriterAndCheck(pushData.`type`(), location, isPrimary, callback) 
match {
@@ -849,45 +849,120 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
   }
 
   private def handleRpcRequest(client: TransportClient, rpcRequest: 
RpcRequest): Unit = {
-    val msg = Message.decode(rpcRequest.body().nioByteBuffer())
     val requestId = rpcRequest.requestId
-    val (mode, shuffleKey, partitionUniqueId, checkSplit) = msg match {
-      case p: PushDataHandShake => (p.mode, p.shuffleKey, p.partitionUniqueId, 
true)
-      case rs: RegionStart => (rs.mode, rs.shuffleKey, rs.partitionUniqueId, 
true)
-      case rf: RegionFinish => (rf.mode, rf.shuffleKey, rf.partitionUniqueId, 
false)
-    }
+    val (pbMsg, msg, isLegacy, messageType, mode, shuffleKey, 
partitionUniqueId, checkSplit) =
+      mapPartitionRpcRequest(rpcRequest)
     handleCore(
       client,
       rpcRequest,
       requestId,
       () =>
         handleMapPartitionRpcRequestCore(
-          mode,
+          requestId,
+          pbMsg,
           msg,
+          isLegacy,
+          messageType,
+          mode,
           shuffleKey,
           partitionUniqueId,
-          requestId,
           checkSplit,
           new SimpleRpcResponseCallback(
             client,
             requestId,
             shuffleKey)))
+  }
 
+  private def mapPartitionRpcRequest(rpcRequest: RpcRequest)
+      : Tuple8[GeneratedMessageV3, Message, Boolean, Type, Mode, String, 
String, Boolean] = {
+    try {
+      val msg = TransportMessage.fromByteBuffer(
+        
rpcRequest.body().nioByteBuffer()).getParsedPayload.asInstanceOf[GeneratedMessageV3]
+      msg match {
+        case p: PbPushDataHandShake =>
+          (
+            msg,
+            null,
+            false,
+            Type.PUSH_DATA_HAND_SHAKE,
+            p.getMode,
+            p.getShuffleKey,
+            p.getPartitionUniqueId,
+            true)
+        case rs: PbRegionStart =>
+          (
+            msg,
+            null,
+            false,
+            Type.REGION_START,
+            rs.getMode,
+            rs.getShuffleKey,
+            rs.getPartitionUniqueId,
+            true)
+        case rf: PbRegionFinish =>
+          (
+            msg,
+            null,
+            false,
+            Type.REGION_FINISH,
+            rf.getMode,
+            rf.getShuffleKey,
+            rf.getPartitionUniqueId,
+            false)
+      }
+    } catch {
+      case _: Exception =>
+        val msg = Message.decode(rpcRequest.body().nioByteBuffer())
+        msg match {
+          case p: PushDataHandShake =>
+            (
+              null,
+              msg,
+              true,
+              Type.PUSH_DATA_HAND_SHAKE,
+              Mode.forNumber(p.mode),
+              p.shuffleKey,
+              p.partitionUniqueId,
+              true)
+          case rs: RegionStart =>
+            (
+              null,
+              msg,
+              true,
+              Type.REGION_START,
+              Mode.forNumber(rs.mode),
+              rs.shuffleKey,
+              rs.partitionUniqueId,
+              true)
+          case rf: RegionFinish =>
+            (
+              null,
+              msg,
+              true,
+              Type.REGION_FINISH,
+              Mode.forNumber(rf.mode),
+              rf.shuffleKey,
+              rf.partitionUniqueId,
+              false)
+        }
+    }
   }
 
   private def handleMapPartitionRpcRequestCore(
-      mode: Byte,
-      message: Message,
+      requestId: Long,
+      pbMsg: GeneratedMessageV3,
+      msg: Message,
+      isLegacy: Boolean,
+      messageType: Message.Type,
+      mode: Mode,
       shuffleKey: String,
       partitionUniqueId: String,
-      requestId: Long,
       checkSplit: Boolean,
       callback: RpcResponseCallback): Unit = {
-    val isPrimary = PartitionLocation.getMode(mode) == 
PartitionLocation.Mode.PRIMARY
-    val messageType = message.`type`()
     log.debug(
       s"requestId:$requestId, pushdata rpc:$messageType, mode:$mode, 
shuffleKey:$shuffleKey, " +
         s"partitionUniqueId:$partitionUniqueId")
+    val isPrimary = mode == Mode.Primary
     val (workerSourcePrimary, workerSourceReplica) =
       messageType match {
         case Type.PUSH_DATA_HAND_SHAKE =>
@@ -924,10 +999,8 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
         messageType,
         shuffleKey,
         partitionUniqueId,
-        null,
         location,
-        callback,
-        wrappedCallback)) return
+        callback)) return
 
     val fileWriter =
       getFileWriterAndCheck(messageType, location, isPrimary, callback) match {
@@ -957,13 +1030,31 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
     try {
       messageType match {
         case Type.PUSH_DATA_HAND_SHAKE =>
+          val (numPartitions, bufferSize) =
+            if (isLegacy)
+              (
+                msg.asInstanceOf[PushDataHandShake].numPartitions,
+                msg.asInstanceOf[PushDataHandShake].bufferSize)
+            else
+              (
+                pbMsg.asInstanceOf[PbPushDataHandShake].getNumPartitions,
+                pbMsg.asInstanceOf[PbPushDataHandShake].getBufferSize)
           fileWriter.asInstanceOf[MapPartitionFileWriter].pushDataHandShake(
-            message.asInstanceOf[PushDataHandShake].numPartitions,
-            message.asInstanceOf[PushDataHandShake].bufferSize)
+            numPartitions,
+            bufferSize)
         case Type.REGION_START =>
+          val (currentRegionIndex, isBroadcast) =
+            if (isLegacy)
+              (
+                msg.asInstanceOf[RegionStart].currentRegionIndex,
+                msg.asInstanceOf[RegionStart].isBroadcast)
+            else
+              (
+                pbMsg.asInstanceOf[PbRegionStart].getCurrentRegionIndex,
+                Boolean.box(pbMsg.asInstanceOf[PbRegionStart].getIsBroadcast))
           fileWriter.asInstanceOf[MapPartitionFileWriter].regionStart(
-            message.asInstanceOf[RegionStart].currentRegionIndex,
-            message.asInstanceOf[RegionStart].isBroadcast)
+            currentRegionIndex,
+            isBroadcast)
         case Type.REGION_FINISH =>
           fileWriter.asInstanceOf[MapPartitionFileWriter].regionFinish()
         case _ => throw new IllegalArgumentException(s"Not support 
$messageType yet")
@@ -1039,10 +1130,8 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
       messageType: Message.Type,
       shuffleKey: String,
       partitionUniqueId: String,
-      body: ByteBuf,
       location: PartitionLocation,
-      callback: RpcResponseCallback,
-      wrappedCallback: RpcResponseCallback): Boolean = {
+      callback: RpcResponseCallback): Boolean = {
     if (location == null) {
       val msg =
         s"Partition location wasn't found for task(shuffle $shuffleKey, 
uniqueId $partitionUniqueId)."

Reply via email to