This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 30850a658 [CELEBORN-1932][CIP-14] Adapt java's serialization to 
support cpp serialization for GetReducerFileGroup/Response
30850a658 is described below

commit 30850a6586abe319d96869f1da028defd8e9546f
Author: HolyLow <[email protected]>
AuthorDate: Thu Apr 3 14:43:47 2025 +0800

    [CELEBORN-1932][CIP-14] Adapt java's serialization to support cpp 
serialization for GetReducerFileGroup/Response
    
    ### What changes were proposed in this pull request?
    Java's existing serialization is adapted to support multi-language 
serialization. In addition, GetReducerFileGroup and GetReducerFileGroupResponse 
are adapted to support both the Java (V1) and C++ (V2) serialization modes.
    
    ### Why are the changes needed?
    To enable CppClient to communicate with JavaServer.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation.
    
    Closes #3177 from HolyLow/issue/celeborn-1932-adapt-java-serialization.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 .../apache/celeborn/client/ShuffleClientImpl.java  |  8 ++---
 .../org/apache/celeborn/client/CommitManager.scala |  8 +++--
 .../apache/celeborn/client/LifecycleManager.scala  | 17 +++++----
 .../celeborn/client/commit/CommitHandler.scala     |  6 +++-
 .../client/commit/MapPartitionCommitHandler.scala  |  9 +++--
 .../commit/ReducePartitionCommitHandler.scala      | 42 ++++++++++++++++------
 .../celeborn/client/ShuffleClientSuiteJ.java       | 25 ++++++++-----
 .../common/network/protocol/SerdeVersion.java      | 41 +++++++++++++++++++++
 .../common/network/protocol/TransportMessage.java  | 17 ++++++++-
 .../common/protocol/message/ControlMessages.scala  | 22 +++++++-----
 .../celeborn/common/rpc/netty/NettyRpcEnv.scala    | 16 ++++++++-
 .../common/serializer/JavaSerializer.scala         | 31 ++++++++++++++++
 12 files changed, 198 insertions(+), 44 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index e7f4a9083..12ef7d330 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -53,10 +53,8 @@ import 
org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientBootstrap;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.PushData;
-import org.apache.celeborn.common.network.protocol.PushMergedData;
-import org.apache.celeborn.common.network.protocol.TransportMessage;
-import org.apache.celeborn.common.network.protocol.TransportMessagesHelper;
+import org.apache.celeborn.common.network.protocol.*;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
 import org.apache.celeborn.common.network.sasl.SaslClientBootstrap;
 import org.apache.celeborn.common.network.sasl.SaslCredentials;
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
@@ -1803,7 +1801,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     }
     try {
       GetReducerFileGroup getReducerFileGroup =
-          new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible);
+          new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible, 
SerdeVersion.V1);
 
       GetReducerFileGroupResponse response =
           lifecycleManagerRef.askSync(
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index ad0d66e3a..bffd05430 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -34,6 +34,7 @@ import org.apache.celeborn.client.listener.{WorkersStatus, 
WorkerStatusListener}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.WorkerInfo
+import org.apache.celeborn.common.network.protocol.SerdeVersion
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType, 
StorageInfo}
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc.RpcCallContext
@@ -275,8 +276,11 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
     getCommitHandler(shuffleId).waitStageEnd(shuffleId)
   }
 
-  def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit 
= {
-    getCommitHandler(shuffleId).handleGetReducerFileGroup(context, shuffleId)
+  def handleGetReducerFileGroup(
+      context: RpcCallContext,
+      shuffleId: Int,
+      serdeVersion: SerdeVersion): Unit = {
+    getCommitHandler(shuffleId).handleGetReducerFileGroup(context, shuffleId, 
serdeVersion)
   }
 
   // exposed for test
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 35d1945e1..9eac9afb8 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -45,7 +45,7 @@ import org.apache.celeborn.common.identity.{IdentityProvider, 
UserIdentifier}
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{ApplicationMeta, 
ShufflePartitionLocationInfo, WorkerInfo}
 import org.apache.celeborn.common.metrics.source.Role
-import org.apache.celeborn.common.network.protocol.TransportMessagesHelper
+import org.apache.celeborn.common.network.protocol.{SerdeVersion, 
TransportMessagesHelper}
 import org.apache.celeborn.common.network.sasl.registration.RegistrationInfo
 import org.apache.celeborn.common.protocol._
 import org.apache.celeborn.common.protocol.RpcNameConstants.WORKER_EP
@@ -432,10 +432,13 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
           throw new UnsupportedOperationException(s"Not support $partitionType 
yet")
       }
 
-    case GetReducerFileGroup(shuffleId: Int, isSegmentGranularityVisible: 
Boolean) =>
+    case GetReducerFileGroup(
+          shuffleId: Int,
+          isSegmentGranularityVisible: Boolean,
+          serdeVersion: SerdeVersion) =>
       logDebug(
         s"Received GetShuffleFileGroup request for shuffleId $shuffleId, 
isSegmentGranularityVisible $isSegmentGranularityVisible")
-      handleGetReducerFileGroup(context, shuffleId, 
isSegmentGranularityVisible)
+      handleGetReducerFileGroup(context, shuffleId, 
isSegmentGranularityVisible, serdeVersion)
 
     case pb: PbGetShuffleId =>
       val appShuffleId = pb.getAppShuffleId
@@ -845,7 +848,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
   private def handleGetReducerFileGroup(
       context: RpcCallContext,
       shuffleId: Int,
-      isSegmentGranularityVisible: Boolean): Unit = {
+      isSegmentGranularityVisible: Boolean,
+      serdeVersion: SerdeVersion): Unit = {
     // If isSegmentGranularityVisible is set to true, the downstream reduce 
task may start early than upstream map task, e.g. flink hybrid shuffle.
     // Under these circumstances, there's a possibility that the shuffle might 
not yet be registered when the downstream reduce task send GetReduceFileGroup 
request,
     // so we shouldn't send a SHUFFLE_NOT_REGISTERED response directly, should 
enqueue this request to pending list, and response to the downstream reduce 
task the ReduceFileGroup when the upstream map task register shuffle done
@@ -854,10 +858,11 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
       context.reply(GetReducerFileGroupResponse(
         StatusCode.SHUFFLE_NOT_REGISTERED,
         JavaUtils.newConcurrentHashMap(),
-        Array.empty))
+        Array.empty,
+        serdeVersion = serdeVersion))
       return
     }
-    commitManager.handleGetReducerFileGroup(context, shuffleId)
+    commitManager.handleGetReducerFileGroup(context, shuffleId, serdeVersion)
   }
 
   private def handleGetShuffleIdForApp(
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index ea86b828d..658eb47a7 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -34,6 +34,7 @@ import 
org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, Shuffl
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, 
WorkerInfo}
+import org.apache.celeborn.common.network.protocol.SerdeVersion
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, 
CommitFilesResponse}
 import org.apache.celeborn.common.protocol.message.StatusCode
@@ -178,7 +179,10 @@ abstract class CommitHandler(
    * partitions are complete by the time the method is called, as downstream 
tasks may start early before all tasks
    * are completed.So map partition may need refresh reducer file group if 
needed.
    */
-  def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit
+  def handleGetReducerFileGroup(
+      context: RpcCallContext,
+      shuffleId: Int,
+      serdeVersion: SerdeVersion): Unit
 
   def removeExpiredShuffle(shuffleId: Int): Unit = {
     reducerFileGroupsMap.remove(shuffleId)
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index a08f1e0d5..4f31018e5 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -31,6 +31,7 @@ import 
org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, Shu
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, 
WorkerInfo}
+import org.apache.celeborn.common.network.protocol.SerdeVersion
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
 import org.apache.celeborn.common.protocol.message.StatusCode
@@ -230,7 +231,10 @@ class MapPartitionCommitHandler(
     shuffleIsSegmentGranularityVisible.get(shuffleId)
   }
 
-  override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: 
Int): Unit = {
+  override def handleGetReducerFileGroup(
+      context: RpcCallContext,
+      shuffleId: Int,
+      serdeVersion: SerdeVersion): Unit = {
     // TODO: if support the downstream map task start early before the 
upstream reduce task, it should
     //  waiting the upstream task register shuffle, then reply these 
GetReducerFileGroup.
     //  Note that flink hybrid shuffle should support it in the future.
@@ -244,7 +248,8 @@ class MapPartitionCommitHandler(
       StatusCode.SUCCESS,
       reducerFileGroupsMap.getOrDefault(shuffleId, 
JavaUtils.newConcurrentHashMap()),
       getMapperAttempts(shuffleId),
-      succeedPartitionIds))
+      succeedPartitionIds,
+      serdeVersion = serdeVersion))
   }
 
   override def releasePartitionResource(shuffleId: Int, partitionId: Int): 
Unit = {
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 98fe624fb..5bdd1c550 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -34,6 +34,7 @@ import 
org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, Shu
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.ShufflePartitionLocationInfo
+import org.apache.celeborn.common.network.protocol.SerdeVersion
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
 import org.apache.celeborn.common.protocol.message.StatusCode
@@ -65,8 +66,10 @@ class ReducePartitionCommitHandler(
     sharedRpcPool)
   with Logging {
 
+  class MultiSerdeVersionRpcContext(val ctx: RpcCallContext, val serdeVersion: 
SerdeVersion) {}
+
   private val getReducerFileGroupRequest =
-    JavaUtils.newConcurrentHashMap[Int, util.Set[RpcCallContext]]()
+    JavaUtils.newConcurrentHashMap[Int, 
util.Set[MultiSerdeVersionRpcContext]]()
   private val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
@@ -300,7 +303,7 @@ class ReducePartitionCommitHandler(
       numMappers: Int,
       isSegmentGranularityVisible: Boolean): Unit = {
     super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
-    getReducerFileGroupRequest.put(shuffleId, new 
util.HashSet[RpcCallContext]())
+    getReducerFileGroupRequest.put(shuffleId, new 
util.HashSet[MultiSerdeVersionRpcContext]())
     initMapperAttempts(shuffleId, numMappers)
   }
 
@@ -314,7 +317,16 @@ class ReducePartitionCommitHandler(
     }
   }
 
-  private def replyGetReducerFileGroup(context: RpcCallContext, shuffleId: 
Int): Unit = {
+  private def replyGetReducerFileGroup(
+      context: MultiSerdeVersionRpcContext,
+      shuffleId: Int): Unit = {
+    replyGetReducerFileGroup(context.ctx, shuffleId, context.serdeVersion)
+  }
+
+  private def replyGetReducerFileGroup(
+      context: RpcCallContext,
+      shuffleId: Int,
+      serdeVersion: SerdeVersion): Unit = {
     if (isStageDataLost(shuffleId)) {
       context.reply(
         GetReducerFileGroupResponse(
@@ -328,7 +340,8 @@ class ReducePartitionCommitHandler(
         var response = GetReducerFileGroupResponse(
           StatusCode.SUCCESS,
           reducerFileGroupsMap.getOrDefault(shuffleId, 
JavaUtils.newConcurrentHashMap()),
-          getMapperAttempts(shuffleId))
+          getMapperAttempts(shuffleId),
+          serdeVersion = serdeVersion)
 
         // only check whether broadcast enabled for the UTs
         if (getReducerFileGroupResponseBroadcastEnabled) {
@@ -348,7 +361,8 @@ class ReducePartitionCommitHandler(
                 pushFailedBatches =
                   shufflePushFailedBatches.getOrDefault(
                     shuffleId,
-                    new util.HashMap[String, util.Set[PushFailedBatch]]()))
+                    new util.HashMap[String, util.Set[PushFailedBatch]]()),
+                serdeVersion = serdeVersion)
 
               val serializedMsg =
                 
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
@@ -382,22 +396,30 @@ class ReducePartitionCommitHandler(
       response: GetReducerFileGroupResponse): GetReducerFileGroupResponse = {
     lifecycleManager.broadcastGetReducerFileGroupResponse(shuffleId, response) 
match {
       case Some(broadcastBytes) if broadcastBytes.nonEmpty =>
-        GetReducerFileGroupResponse(response.status, broadcast = 
broadcastBytes)
+        GetReducerFileGroupResponse(
+          response.status,
+          broadcast = broadcastBytes,
+          serdeVersion = response.serdeVersion)
       case _ => response
     }
   }
 
-  override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: 
Int): Unit = {
+  override def handleGetReducerFileGroup(
+      context: RpcCallContext,
+      shuffleId: Int,
+      serdeVersion: SerdeVersion): Unit = {
     // Quick return for ended stage, avoid occupy sync lock.
     if (isStageEnd(shuffleId)) {
-      replyGetReducerFileGroup(context, shuffleId)
+      replyGetReducerFileGroup(context, shuffleId, serdeVersion)
     } else {
       getReducerFileGroupRequest.synchronized {
         // If setStageEnd() called after isStageEnd and before got lock, 
should reply here.
         if (isStageEnd(shuffleId)) {
-          replyGetReducerFileGroup(context, shuffleId)
+          replyGetReducerFileGroup(context, shuffleId, serdeVersion)
         } else {
-          getReducerFileGroupRequest.get(shuffleId).add(context)
+          getReducerFileGroupRequest.get(shuffleId).add(new 
MultiSerdeVersionRpcContext(
+            context,
+            serdeVersion))
         }
       }
     }
diff --git 
a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java 
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index a5076a59f..85cf0ba10 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -48,6 +48,7 @@ import 
org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
 import org.apache.celeborn.common.protocol.CompressionCodec;
 import org.apache.celeborn.common.protocol.PartitionLocation;
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse$;
@@ -428,7 +429,8 @@ public class ShuffleClientSuiteJ {
                   new int[0],
                   Collections.emptySet(),
                   Collections.emptyMap(),
-                  new byte[0]);
+                  new byte[0],
+                  SerdeVersion.V1);
             });
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any()))
@@ -441,7 +443,8 @@ public class ShuffleClientSuiteJ {
                   new int[0],
                   Collections.emptySet(),
                   Collections.emptyMap(),
-                  new byte[0]);
+                  new byte[0],
+                  SerdeVersion.V1);
             });
 
     shuffleClient =
@@ -485,7 +488,8 @@ public class ShuffleClientSuiteJ {
                   new int[0],
                   Collections.emptySet(),
                   Collections.emptyMap(),
-                  new byte[0]);
+                  new byte[0],
+                  SerdeVersion.V1);
             });
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any()))
@@ -497,7 +501,8 @@ public class ShuffleClientSuiteJ {
                   new int[0],
                   Collections.emptySet(),
                   Collections.emptyMap(),
-                  new byte[0]);
+                  new byte[0],
+                  SerdeVersion.V1);
             });
 
     shuffleClient =
@@ -519,7 +524,8 @@ public class ShuffleClientSuiteJ {
                   new int[0],
                   Collections.emptySet(),
                   Collections.emptyMap(),
-                  new byte[0]);
+                  new byte[0],
+                  SerdeVersion.V1);
             });
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any()))
@@ -531,7 +537,8 @@ public class ShuffleClientSuiteJ {
                   new int[0],
                   Collections.emptySet(),
                   Collections.emptyMap(),
-                  new byte[0]);
+                  new byte[0],
+                  SerdeVersion.V1);
             });
 
     shuffleClient =
@@ -553,7 +560,8 @@ public class ShuffleClientSuiteJ {
                   new int[0],
                   Collections.emptySet(),
                   Collections.emptyMap(),
-                  new byte[0]);
+                  new byte[0],
+                  SerdeVersion.V1);
             });
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any()))
@@ -565,7 +573,8 @@ public class ShuffleClientSuiteJ {
                   new int[0],
                   Collections.emptySet(),
                   Collections.emptyMap(),
-                  new byte[0]);
+                  new byte[0],
+                  SerdeVersion.V1);
             });
 
     shuffleClient =
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SerdeVersion.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SerdeVersion.java
new file mode 100644
index 000000000..177b9bb7d
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SerdeVersion.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+/**
+ * SerdeVersion represents which ser/de version the message is deserialized 
from / will be
+ * serialized into. For V1 (used by legacy java engine), the ser/de is 
dependent on java's
+ * serialization stack, and the leading byte would be 0xAC according to Java's 
serialization stack.
+ * For V2 (used by cpp client), the ser/de is language-agnostic, the leading 
byte would be 0xFF as
+ * defined in CelebornCpp module. In this way, messages from/for different 
version could be
+ * distinguished and ser/de accordingly.
+ */
+public enum SerdeVersion {
+  V1((byte) 0xAC),
+  V2((byte) 0xFF);
+
+  private final byte marker;
+
+  SerdeVersion(byte marker) {
+    this.marker = marker;
+  }
+
+  public byte getMarker() {
+    return marker;
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 01a9a37f9..137c2e710 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -36,11 +36,17 @@ public class TransportMessage implements Serializable {
   @Deprecated private final transient MessageType type;
   private final int messageTypeValue;
   private final byte[] payload;
+  private final SerdeVersion serdeVersion;
 
   public TransportMessage(MessageType type, byte[] payload) {
+    this(type, payload, SerdeVersion.V1);
+  }
+
+  public TransportMessage(MessageType type, byte[] payload, SerdeVersion 
serdeVersion) {
     this.type = type;
     this.messageTypeValue = type.getNumber();
     this.payload = payload;
+    this.serdeVersion = serdeVersion;
   }
 
   public MessageType getType() {
@@ -55,6 +61,10 @@ public class TransportMessage implements Serializable {
     return payload;
   }
 
+  public SerdeVersion getSerdeVersion() {
+    return serdeVersion;
+  }
+
   public <T extends GeneratedMessageV3> T getParsedPayload() throws 
InvalidProtocolBufferException {
     switch (messageTypeValue) {
       case OPEN_STREAM_VALUE:
@@ -132,6 +142,11 @@ public class TransportMessage implements Serializable {
   }
 
   public static TransportMessage fromByteBuffer(ByteBuffer buffer) throws 
CelebornIOException {
+    return fromByteBuffer(buffer, SerdeVersion.V1);
+  }
+
+  public static TransportMessage fromByteBuffer(ByteBuffer buffer, 
SerdeVersion serdeVersion)
+      throws CelebornIOException {
     int messageTypeValue = buffer.getInt();
     if (MessageType.forNumber(messageTypeValue) == null) {
       throw new CelebornIOException("Decode failed, fallback to legacy 
messages.");
@@ -140,6 +155,6 @@ public class TransportMessage implements Serializable {
     byte[] payload = new byte[payloadLen];
     buffer.get(payload);
     MessageType msgType = MessageType.forNumber(messageTypeValue);
-    return new TransportMessage(msgType, payload);
+    return new TransportMessage(msgType, payload, serdeVersion);
   }
 }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 949b13322..8719dea7a 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -28,7 +28,7 @@ import org.roaringbitmap.RoaringBitmap
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{DiskInfo, WorkerInfo, WorkerStatus}
-import org.apache.celeborn.common.network.protocol.TransportMessage
+import org.apache.celeborn.common.network.protocol.{SerdeVersion, 
TransportMessage}
 import org.apache.celeborn.common.protocol._
 import org.apache.celeborn.common.protocol.MessageType._
 import org.apache.celeborn.common.quota.ResourceConsumption
@@ -279,7 +279,10 @@ object ControlMessages extends Logging {
 
   case class MapperEndResponse(status: StatusCode) extends MasterMessage
 
-  case class GetReducerFileGroup(shuffleId: Int, isSegmentGranularityVisible: 
Boolean)
+  case class GetReducerFileGroup(
+      shuffleId: Int,
+      isSegmentGranularityVisible: Boolean,
+      serdeVersion: SerdeVersion)
     extends MasterMessage
 
   // util.Set[String] -> util.Set[Path.toString]
@@ -290,7 +293,8 @@ object ControlMessages extends Logging {
       attempts: Array[Int] = Array.emptyIntArray,
       partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
       pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = 
Collections.emptyMap(),
-      broadcast: Array[Byte] = Array.emptyByteArray)
+      broadcast: Array[Byte] = Array.emptyByteArray,
+      serdeVersion: SerdeVersion = SerdeVersion.V1)
     extends MasterMessage
 
   object WorkerExclude {
@@ -747,12 +751,12 @@ object ControlMessages extends Logging {
         .build().toByteArray
       new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload)
 
-    case GetReducerFileGroup(shuffleId, isSegmentGranularityVisible) =>
+    case GetReducerFileGroup(shuffleId, isSegmentGranularityVisible, 
serdeVersion) =>
       val payload = PbGetReducerFileGroup.newBuilder()
         .setShuffleId(shuffleId)
         .setIsSegmentGranularityVisible(isSegmentGranularityVisible)
         .build().toByteArray
-      new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload)
+      new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload, 
serdeVersion)
 
     case GetReducerFileGroupResponse(
           status,
@@ -760,7 +764,8 @@ object ControlMessages extends Logging {
           attempts,
           partitionIds,
           failedBatches,
-          broadcast) =>
+          broadcast,
+          serdeVersion) =>
       val builder = PbGetReducerFileGroupResponse
         .newBuilder()
         .setStatus(status.getValue)
@@ -780,7 +785,7 @@ object ControlMessages extends Logging {
         }.asJava)
       builder.setBroadcast(ByteString.copyFrom(broadcast))
       val payload = builder.build().toByteArray
-      new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, 
payload)
+      new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, 
payload, serdeVersion)
 
     case pb: PbWorkerExclude =>
       new TransportMessage(MessageType.WORKER_EXCLUDE, pb.toByteArray)
@@ -1177,7 +1182,8 @@ object ControlMessages extends Logging {
         val pbGetReducerFileGroup = 
PbGetReducerFileGroup.parseFrom(message.getPayload)
         GetReducerFileGroup(
           pbGetReducerFileGroup.getShuffleId,
-          pbGetReducerFileGroup.getIsSegmentGranularityVisible)
+          pbGetReducerFileGroup.getIsSegmentGranularityVisible,
+          message.getSerdeVersion)
 
       case GET_REDUCER_FILE_GROUP_RESPONSE_VALUE =>
         val pbGetReducerFileGroupResponse = PbGetReducerFileGroupResponse
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
index 1ce989134..b2de46559 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
@@ -34,7 +34,7 @@ import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.network.TransportContext
 import org.apache.celeborn.common.network.client._
-import org.apache.celeborn.common.network.protocol.{RequestMessage => 
NRequestMessage, RpcRequest}
+import org.apache.celeborn.common.network.protocol.{RequestMessage => 
NRequestMessage, RpcRequest, SerdeVersion, TransportMessage}
 import org.apache.celeborn.common.network.sasl.{SaslClientBootstrap, 
SaslServerBootstrap}
 import 
org.apache.celeborn.common.network.sasl.registration.{RegistrationClientBootstrap,
 RegistrationServerBootstrap}
 import org.apache.celeborn.common.network.server._
@@ -504,6 +504,20 @@ private[celeborn] class RequestMessage(
       writeRpcAddress(out, senderAddress)
       writeRpcAddress(out, receiver.address)
       out.writeUTF(receiver.name)
+      val msg = Utils.toTransportMessage(content)
+      msg match {
+        case transMsg: TransportMessage =>
+          // Check if the msg is a TransportMessage with language-agnostic V2 
serdeVersion.
+          // If so, write the marker and the body explicitly.
+          if (transMsg.getSerdeVersion == SerdeVersion.V2) {
+            val out = new DataOutputStream(bos)
+            out.writeByte(SerdeVersion.V2.getMarker)
+            out.write(transMsg.toByteBuffer.array)
+            out.close()
+            return bos.toByteBuffer
+          }
+        case _ =>
+      }
       val s = nettyEnv.serializeStream(out)
       try {
         s.writeObject(Utils.toTransportMessage(content))
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala
 
b/common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala
index d38161583..3a813ef9c 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 import scala.reflect.ClassTag
 
 import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.network.protocol.{SerdeVersion, 
TransportMessage}
 import org.apache.celeborn.common.util.{ByteBufferInputStream, 
ByteBufferOutputStream, Utils}
 
 private[celeborn] class JavaSerializationStream(
@@ -98,6 +99,20 @@ private[celeborn] class JavaSerializerInstance(
 
   override def serialize[T: ClassTag](t: T): ByteBuffer = {
     val bos = new ByteBufferOutputStream()
+    val msg = Utils.toTransportMessage(t)
+    msg match {
+      case transMsg: TransportMessage =>
+        // Check if the msg is a TransportMessage with language-agnostic V2 
serdeVersion.
+        // If so, write the marker and the body explicitly.
+        if (transMsg.getSerdeVersion == SerdeVersion.V2) {
+          val out = new DataOutputStream(bos)
+          out.writeByte(SerdeVersion.V2.getMarker)
+          out.write(transMsg.toByteBuffer.array)
+          out.close()
+          return bos.toByteBuffer
+        }
+      case _ =>
+    }
     val out = serializeStream(bos)
     out.writeObject(Utils.toTransportMessage(t))
     out.close()
@@ -105,12 +120,28 @@ private[celeborn] class JavaSerializerInstance(
   }
 
   override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+    bytes.mark
+    val serdeVersion = bytes.get
+    // If the serdeVersion byte is V2, deserialize directly.
+    if (serdeVersion == SerdeVersion.V2.getMarker) {
+      return Utils.fromTransportMessage(
+        TransportMessage.fromByteBuffer(bytes, 
SerdeVersion.V2)).asInstanceOf[T]
+    }
+    bytes.reset
     val bis = new ByteBufferInputStream(bytes)
     val in = deserializeStream(bis)
     Utils.fromTransportMessage(in.readObject()).asInstanceOf[T]
   }
 
   override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: 
ClassLoader): T = {
+    bytes.mark
+    val serdeVersion = bytes.get
+    // If the serdeVersion byte is V2, deserialize directly.
+    if (serdeVersion == SerdeVersion.V2.getMarker) {
+      return Utils.fromTransportMessage(
+        TransportMessage.fromByteBuffer(bytes, 
SerdeVersion.V2)).asInstanceOf[T]
+    }
+    bytes.reset
     val bis = new ByteBufferInputStream(bytes)
     val in = deserializeStream(bis, loader)
     Utils.fromTransportMessage(in.readObject()).asInstanceOf[T]

Reply via email to