This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new d30c02e36 [CELEBORN-2235][CIP-14] Adapt Java end's serialization to CppWriterClient
d30c02e36 is described below

commit d30c02e3690a5ecfe3cbb32f2d21ba101d37c679
Author: HolyLow <[email protected]>
AuthorDate: Mon Jan 5 22:24:22 2026 +0800

    [CELEBORN-2235][CIP-14] Adapt Java end's serialization to CppWriterClient
    
    ### What changes were proposed in this pull request?
    This PR adapts the Java end's serialization to CppWriterClient, including RegisterShuffle/Response, Revive/Response, and MapperEnd/Response. A joint test for the cpp-write java-read procedure is included as well.
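
    As a sketch of the core change (excerpted from the ControlMessages.scala hunk below), each adapted control message now carries an explicit SerdeVersion, so replies can be serialized in the format the requesting client used:

        // Excerpt from ControlMessages.scala in this diff.
        case class RegisterShuffle(
            shuffleId: Int,
            numMappers: Int,
            numPartitions: Int,
            serdeVersion: SerdeVersion)
          extends MasterMessage

        case class MapperEndResponse(status: StatusCode, serdeVersion: SerdeVersion) extends MasterMessage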
    
    ### Why are the changes needed?
    Support writing to the Celeborn server with CppWriterClient.
    
    ### Does this PR resolve a correctness bug?
    No.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and integration tests.
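
    As a concrete example, the cpp-write java-read test added here is driven from Maven the same way as the existing hybrid tests; the new CI step in cpp_integration.yml (see the workflow hunk below) runs:

        build/mvn -pl worker \
          test-compile exec:java \
          -Dexec.classpathScope="test" \
          -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.CppWriteJavaReadTestWithNONE" \
          -Dexec.args="-XX:MaxDirectMemorySize=2G"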
    
    Closes #3561 from HolyLow/issue/celeborn-2235-adapt-java-to-cpp-writer-serialization.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .github/workflows/cpp_integration.yml              |  13 +-
 .../flink/client/FlinkShuffleClientImpl.java       |   5 +-
 .../apache/celeborn/client/ShuffleClientImpl.java  |  56 +++---
 .../apache/celeborn/client/LifecycleManager.scala  | 114 ++++++-----
 .../client/RequestLocationCallContext.scala        |  17 +-
 .../celeborn/client/ShuffleClientBaseSuiteJ.java   |  12 +-
 .../celeborn/client/ShuffleClientSuiteJ.java       |   4 +-
 .../common/protocol/message/ControlMessages.scala  | 217 +++++++++++++--------
 .../apache/celeborn/common/util/UtilsSuite.scala   |  13 +-
 cpp/celeborn/tests/CMakeLists.txt                  |  27 ++-
 cpp/celeborn/tests/DataSumWithWriterClient.cpp     |  96 +++++++++
 ...Z4.scala => CppWriteJavaReadTestWithNONE.scala} |   4 +-
 ....scala => JavaCppHybridReadWriteTestBase.scala} |  98 +++++++++-
 .../cluster/JavaWriteCppReadTestWithLZ4.scala      |   2 +-
 .../cluster/JavaWriteCppReadTestWithNONE.scala     |   2 +-
 .../cluster/JavaWriteCppReadTestWithZSTD.scala     |   2 +-
 16 files changed, 492 insertions(+), 190 deletions(-)

diff --git a/.github/workflows/cpp_integration.yml b/.github/workflows/cpp_integration.yml
index 1c521f5d0..4f04c4bad 100644
--- a/.github/workflows/cpp_integration.yml
+++ b/.github/workflows/cpp_integration.yml
@@ -85,24 +85,31 @@ jobs:
           check-latest: false
       - name: Compile & Install Celeborn Java
         run: build/mvn clean install -DskipTests
-      - name: Run Java-Cpp Hybrid Integration Test
+      - name: Run Java-Write Cpp-Read Hybrid Integration Test (NONE Decompression)
         run: |
           build/mvn -pl worker \
             test-compile exec:java \
            -Dexec.classpathScope="test" \
            -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithNONE" \
             -Dexec.args="-XX:MaxDirectMemorySize=2G"
-      - name: Run Java-Cpp Hybrid Integration Test (LZ4 Decompression)
+      - name: Run Java-Write Cpp-Read Hybrid Integration Test (LZ4 Decompression)
         run: |
           build/mvn -pl worker \
             test-compile exec:java \
            -Dexec.classpathScope="test" \
            -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithLZ4" \
             -Dexec.args="-XX:MaxDirectMemorySize=2G"
-      - name: Run Java-Cpp Hybrid Integration Test (ZSTD Decompression)
+      - name: Run Java-Write Cpp-Read Hybrid Integration Test (ZSTD Decompression)
         run: |
           build/mvn -pl worker \
             test-compile exec:java \
            -Dexec.classpathScope="test" \
            -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithZSTD" \
             -Dexec.args="-XX:MaxDirectMemorySize=2G"
+      - name: Run Cpp-Write Java-Read Hybrid Integration Test (NONE Compression)
+        run: |
+          build/mvn -pl worker \
+            test-compile exec:java \
+            -Dexec.classpathScope="test" \
+            -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.CppWriteJavaReadTestWithNONE" \
+            -Dexec.args="-XX:MaxDirectMemorySize=2G"
\ No newline at end of file
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
index 87a80006a..6f52b4aab 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
@@ -46,6 +46,7 @@ import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.network.protocol.PushData;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
 import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.util.TransportConf;
 import org.apache.celeborn.common.protocol.MessageType;
@@ -528,7 +529,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
   public Optional<PartitionLocation> revive(
       int shuffleId, int mapId, int attemptId, PartitionLocation location)
       throws CelebornIOException {
-    Set<Integer> mapIds = new HashSet<>();
+    List<Integer> mapIds = new ArrayList<>();
     mapIds.add(mapId);
     List<ReviveRequest> requests = new ArrayList<>();
     ReviveRequest req =
@@ -543,7 +544,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
     requests.add(req);
     PbChangeLocationResponse response =
         lifecycleManagerRef.askSync(
-            ControlMessages.Revive$.MODULE$.apply(shuffleId, mapIds, requests),
+            ControlMessages.Revive$.MODULE$.apply(shuffleId, mapIds, requests, SerdeVersion.V1),
             conf.clientRpcRequestPartitionLocationAskTimeout(),
             ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
     // per partitionKey only serve single PartitionLocation in Client Cache.
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index cfe40a296..f6a7d9750 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -550,11 +550,11 @@ public class ShuffleClientImpl extends ShuffleClient {
         numPartitions,
         () ->
             lifecycleManagerRef.askSync(
-                RegisterShuffle$.MODULE$.apply(shuffleId, numMappers, numPartitions),
+                new RegisterShuffle(shuffleId, numMappers, numPartitions, SerdeVersion.V1),
                 conf.clientRpcRegisterShuffleAskTimeout(),
                 rpcMaxRetries,
                 rpcRetryWait,
-                ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+                ClassTag$.MODULE$.apply(RegisterShuffleResponse.class)));
   }
 
   @Override
@@ -593,7 +593,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                         partitionId,
                         isSegmentGranularityVisible),
                     conf.clientRpcRegisterShuffleAskTimeout(),
-                    ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+                    ClassTag$.MODULE$.apply(RegisterShuffleResponse.class)));
 
     return partitionLocationMap.get(partitionId);
   }
@@ -709,23 +709,18 @@ public class ShuffleClientImpl extends ShuffleClient {
   }
 
   private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
-      int shuffleId,
-      int numMappers,
-      int numPartitions,
-      Callable<PbRegisterShuffleResponse> callable)
+      int shuffleId, int numMappers, int numPartitions, Callable<RegisterShuffleResponse> callable)
       throws CelebornIOException {
     int numRetries = registerShuffleMaxRetries;
     StatusCode lastFailedStatusCode = null;
     while (numRetries > 0) {
       try {
-        PbRegisterShuffleResponse response = callable.call();
-        StatusCode respStatus = StatusCode.fromValue(response.getStatus());
+        RegisterShuffleResponse response = callable.call();
+        StatusCode respStatus = response.status();
         if (StatusCode.SUCCESS.equals(respStatus)) {
           ConcurrentHashMap<Integer, PartitionLocation> result = JavaUtils.newConcurrentHashMap();
-          Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations =
-              PbSerDeUtils.fromPbPackedPartitionLocationsPair(
-                  response.getPackedPartitionLocationsPair());
-          for (PartitionLocation location : locations._1) {
+          PartitionLocation[] locations = response.partitionLocations();
+          for (PartitionLocation location : locations) {
             pushExcludedWorkers.remove(location.hostAndPushPort());
             if (location.hasPeer()) {
               pushExcludedWorkers.remove(location.getPeer().hostAndPushPort());
@@ -900,43 +895,43 @@ public class ShuffleClientImpl extends ShuffleClient {
       oldLocMap.put(req.partitionId, req.loc);
     }
     try {
-      PbChangeLocationResponse response =
+      ChangeLocationResponse response =
           lifecycleManagerRef.askSync(
-              Revive$.MODULE$.apply(shuffleId, mapIds, requests),
+              Revive$.MODULE$.apply(
+                  shuffleId, new ArrayList<>(mapIds), new ArrayList<>(requests), SerdeVersion.V1),
               conf.clientRpcRequestPartitionLocationAskTimeout(),
-              ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
+              ClassTag$.MODULE$.apply(ChangeLocationResponse.class));
 
-      for (int i = 0; i < response.getEndedMapIdCount(); i++) {
-        int mapId = response.getEndedMapId(i);
+      for (int i = 0; i < response.endedMapIds().size(); i++) {
+        int mapId = response.endedMapIds().get(i);
        mapperEndMap.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()).add(mapId);
       }
 
-      for (int i = 0; i < response.getPartitionInfoCount(); i++) {
-        PbChangeLocationPartitionInfo partitionInfo = response.getPartitionInfo(i);
-        int partitionId = partitionInfo.getPartitionId();
-        int statusCode = partitionInfo.getStatus();
-        if (partitionInfo.getOldAvailable()) {
+      for (Map.Entry<Integer, Tuple3<StatusCode, Boolean, PartitionLocation>> entry :
+          response.newLocs().entrySet()) {
+        int partitionId = entry.getKey();
+        StatusCode statusCode = entry.getValue()._1();
+        if (entry.getValue()._2() != null) {
           PartitionLocation oldLoc = oldLocMap.get(partitionId);
          // Currently, revive only check if main location available, here won't remove peer loc.
           pushExcludedWorkers.remove(oldLoc.hostAndPushPort());
         }
 
-        if (StatusCode.SUCCESS.getValue() == statusCode) {
-          PartitionLocation loc =
-              PbSerDeUtils.fromPbPartitionLocation(partitionInfo.getPartition());
+        if (StatusCode.SUCCESS == statusCode) {
+          PartitionLocation loc = entry.getValue()._3();
           partitionLocationMap.put(partitionId, loc);
           pushExcludedWorkers.remove(loc.hostAndPushPort());
           if (loc.hasPeer()) {
             pushExcludedWorkers.remove(loc.getPeer().hostAndPushPort());
           }
-        } else if (StatusCode.STAGE_ENDED.getValue() == statusCode) {
+        } else if (StatusCode.STAGE_ENDED == statusCode) {
           stageEndShuffleSet.add(shuffleId);
           return results;
-        } else if (StatusCode.SHUFFLE_UNREGISTERED.getValue() == statusCode) {
+        } else if (StatusCode.SHUFFLE_UNREGISTERED == statusCode) {
           logger.error("SHUFFLE_NOT_REGISTERED!");
           return null;
         }
-        results.put(partitionId, statusCode);
+        results.put(partitionId, (int) (statusCode.getValue()));
       }
 
       return results;
@@ -1806,7 +1801,8 @@ public class ShuffleClientImpl extends ShuffleClient {
                   pushState.getFailedBatches(),
                   numPartitions,
                   crc32PerPartition,
-                  bytesPerPartition),
+                  bytesPerPartition,
+                  SerdeVersion.V1),
               rpcMaxRetries,
               rpcRetryWait,
               ClassTag$.MODULE$.apply(MapperEndResponse.class));
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 231898535..f48f3cd72 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -156,7 +156,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
   }
 
   case class RegisterCallContext(context: RpcCallContext, partitionId: Int = -1) {
-    def reply(response: PbRegisterShuffleResponse) = {
+    def reply(response: RegisterShuffleResponse) = {
       context.reply(response)
     }
   }
@@ -360,14 +360,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-    case pb: PbRegisterShuffle =>
-      val shuffleId = pb.getShuffleId
-      val numMappers = pb.getNumMappers
-      val numPartitions = pb.getNumPartitions
+    case RegisterShuffle(shuffleId, numMappers, numPartitions, serdeVersion) =>
       logDebug(s"Received RegisterShuffle request, " +
         s"$shuffleId, $numMappers, $numPartitions.")
       offerAndReserveSlots(
         RegisterCallContext(context),
+        serdeVersion,
         shuffleId,
         numMappers,
         numPartitions)
@@ -384,31 +382,25 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       shufflePartitionType.putIfAbsent(shuffleId, PartitionType.MAP)
       offerAndReserveSlots(
         RegisterCallContext(context, partitionId),
+        // Use V1 as this is only supported in java
+        SerdeVersion.V1,
         shuffleId,
         numMappers,
         numMappers,
         partitionId,
         isSegmentGranularityVisible)
 
-    case pb: PbRevive =>
-      val shuffleId = pb.getShuffleId
-      val mapIds = pb.getMapIdList
-      val partitionInfos = pb.getPartitionInfoList
-
+    case Revive(shuffleId, mapIds, reviveRequests, serdeVersion) =>
       val partitionIds = new util.ArrayList[Integer]()
       val epochs = new util.ArrayList[Integer]()
       val oldPartitions = new util.ArrayList[PartitionLocation]()
       val causes = new util.ArrayList[StatusCode]()
-      (0 until partitionInfos.size()).foreach { idx =>
-        val info = partitionInfos.get(idx)
-        partitionIds.add(info.getPartitionId)
-        epochs.add(info.getEpoch)
-        if (info.hasPartition) {
-          oldPartitions.add(PbSerDeUtils.fromPbPartitionLocation(info.getPartition))
-        } else {
-          oldPartitions.add(null)
-        }
-        causes.add(StatusCode.fromValue(info.getStatus))
+      (0 until reviveRequests.size()).foreach { idx =>
+        val request = reviveRequests.get(idx)
+        partitionIds.add(request.partitionId)
+        epochs.add(request.epoch)
+        oldPartitions.add(request.loc)
+        causes.add(request.cause)
       }
       logDebug(s"Received Revive request, number of partitions 
${partitionIds.size()}")
       handleRevive(
@@ -418,7 +410,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
         partitionIds,
         epochs,
         oldPartitions,
-        causes)
+        causes,
+        serdeVersion)
 
     case pb: PbPartitionSplit =>
       val shuffleId = pb.getShuffleId
@@ -428,7 +421,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       logTrace(s"Received split request, " +
         s"$shuffleId, $partitionId, $epoch, $oldPartition")
       changePartitionManager.handleRequestPartitionLocation(
-        ChangeLocationsCallContext(context, 1),
+        // TODO: this message is not supported in cppClient yet.
+        ChangeLocationsCallContext(context, 1, SerdeVersion.V1),
         shuffleId,
         partitionId,
         epoch,
@@ -444,7 +438,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
           pushFailedBatch,
           numPartitions,
           crc32PerPartition,
-          bytesWrittenPerPartition) =>
+          bytesWrittenPerPartition,
+          serdeVersion) =>
       logTrace(s"Received MapperEnd TaskEnd request, " +
         s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}")
       val partitionType = getPartitionType(shuffleId)
@@ -459,7 +454,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
             pushFailedBatch,
             numPartitions,
             crc32PerPartition,
-            bytesWrittenPerPartition)
+            bytesWrittenPerPartition,
+            serdeVersion)
         case PartitionType.MAP =>
           handleMapPartitionEnd(
             context,
@@ -467,7 +463,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
             mapId,
             attemptId,
             partitionId,
-            numMappers)
+            numMappers,
+            serdeVersion)
         case _ =>
           throw new UnsupportedOperationException(s"Not support $partitionType yet")
       }
@@ -618,6 +615,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
 
   private def offerAndReserveSlots(
       context: RegisterCallContext,
+      serdeVersion: SerdeVersion,
       shuffleId: Int,
       numMappers: Int,
       numPartitions: Int,
@@ -641,13 +639,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
               processMapTaskReply(
                 shuffleId,
                 rpcContext,
+                serdeVersion,
                 partitionId,
                 getLatestLocs(shuffleId, p => p.getId == partitionId))
             case PartitionType.REDUCE =>
               if (rpcContext.isInstanceOf[LocalNettyRpcCallContext]) {
                 context.reply(RegisterShuffleResponse(
                   StatusCode.SUCCESS,
-                  getLatestLocs(shuffleId, _ => true)))
+                  getLatestLocs(shuffleId, _ => true),
+                  serdeVersion))
               } else {
                 val cachedMsg = registerShuffleResponseRpcCache.get(
                   shuffleId,
@@ -656,7 +656,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
                       rpcContext.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(
                         RegisterShuffleResponse(
                           StatusCode.SUCCESS,
-                          getLatestLocs(shuffleId, _ => true)))
+                          getLatestLocs(shuffleId, _ => true),
+                          serdeVersion))
                     }
                   })
                rpcContext.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
@@ -699,15 +700,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     def processMapTaskReply(
         shuffleId: Int,
         context: RpcCallContext,
+        serdeVersion: SerdeVersion,
         partitionId: Int,
         partitionLocations: Array[PartitionLocation]): Unit = {
       // if any partition location resource exist just reply
       if (partitionLocations.size > 0) {
-        context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, partitionLocations))
+        context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, partitionLocations, serdeVersion))
       } else {
         // request new resource for this task
         changePartitionManager.handleRequestPartitionLocation(
-          ApplyNewLocationCallContext(context),
+          ApplyNewLocationCallContext(context, serdeVersion),
           shuffleId,
           partitionId,
           -1,
@@ -717,13 +719,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     }
 
     // Reply to all RegisterShuffle request for current shuffle id.
-    def replyRegisterShuffle(response: PbRegisterShuffleResponse): Unit = {
+    def replyRegisterShuffle(response: RegisterShuffleResponse): Unit = {
       registeringShuffleRequest.synchronized {
         val serializedMsg: Option[ByteBuffer] = partitionType match {
           case PartitionType.REDUCE =>
             context.context match {
               case remoteContext: RemoteNettyRpcCallContext =>
-                if (response.getStatus == StatusCode.SUCCESS.getValue) {
+                if (response.status == StatusCode.SUCCESS) {
                   Option(remoteContext.nettyEnv.serialize(
                     response))
                 } else {
@@ -735,19 +737,19 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
           case _ => Option.empty
         }
 
-        val locations = PbSerDeUtils.fromPbPackedPartitionLocationsPair(
-          response.getPackedPartitionLocationsPair)._1.asScala
+        val locations = response.partitionLocations
 
         registeringShuffleRequest.asScala
           .get(shuffleId)
           .foreach(_.asScala.foreach(context => {
             partitionType match {
               case PartitionType.MAP =>
-                if (response.getStatus == StatusCode.SUCCESS.getValue) {
-                  val partitionLocations = locations.filter(_.getId == context.partitionId).toArray
+                if (response.status == StatusCode.SUCCESS) {
+                  val partitionLocations = locations.filter(_.getId == context.partitionId)
                   processMapTaskReply(
                     shuffleId,
                     context.context,
+                    serdeVersion,
                     context.partitionId,
                     partitionLocations)
                 } else {
@@ -757,7 +759,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
                 }
               case PartitionType.REDUCE =>
                 if (context.context.isInstanceOf[
-                    LocalNettyRpcCallContext] || response.getStatus != StatusCode.SUCCESS.getValue) {
+                    LocalNettyRpcCallContext] || response.status != StatusCode.SUCCESS) {
                   context.reply(response)
                 } else {
                   registerShuffleResponseRpcCache.put(shuffleId, serializedMsg.get)
@@ -780,17 +782,26 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     res.status match {
       case StatusCode.REQUEST_FAILED =>
         logInfo(s"OfferSlots RPC request failed for $shuffleId!")
-        replyRegisterShuffle(RegisterShuffleResponse(StatusCode.REQUEST_FAILED, Array.empty))
+        replyRegisterShuffle(RegisterShuffleResponse(
+          StatusCode.REQUEST_FAILED,
+          Array.empty,
+          serdeVersion))
         return
       case StatusCode.SLOT_NOT_AVAILABLE =>
         logInfo(s"OfferSlots for $shuffleId failed!")
-        replyRegisterShuffle(RegisterShuffleResponse(StatusCode.SLOT_NOT_AVAILABLE, Array.empty))
+        replyRegisterShuffle(RegisterShuffleResponse(
+          StatusCode.SLOT_NOT_AVAILABLE,
+          Array.empty,
+          serdeVersion))
         return
       case StatusCode.SUCCESS =>
         logDebug(s"OfferSlots for $shuffleId Success!Slots Info: 
${res.workerResource}")
       case StatusCode.WORKER_EXCLUDED =>
         logInfo(s"OfferSlots for $shuffleId failed due to all workers be 
excluded!")
-        replyRegisterShuffle(RegisterShuffleResponse(StatusCode.WORKER_EXCLUDED, Array.empty))
+        replyRegisterShuffle(RegisterShuffleResponse(
+          StatusCode.WORKER_EXCLUDED,
+          Array.empty,
+          serdeVersion))
         return
       case _ => // won't happen
         throw new UnsupportedOperationException()
@@ -823,7 +834,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     // If reserve slots failed, clear allocated resources, reply ReserveSlotFailed and return.
     if (!reserveSlotsSuccess) {
       logError(s"reserve buffer for $shuffleId failed, reply to all.")
-      replyRegisterShuffle(RegisterShuffleResponse(StatusCode.RESERVE_SLOTS_FAILED, Array.empty))
+      replyRegisterShuffle(RegisterShuffleResponse(
+        StatusCode.RESERVE_SLOTS_FAILED,
+        Array.empty,
+        serdeVersion))
     } else {
       if (log.isDebugEnabled()) {
         logDebug(s"ReserveSlots for $shuffleId success with details:$slots!")
@@ -851,7 +865,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
      val allPrimaryPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
       replyRegisterShuffle(RegisterShuffleResponse(
         StatusCode.SUCCESS,
-        allPrimaryPartitionLocations))
+        allPrimaryPartitionLocations,
+        serdeVersion))
     }
   }
 
@@ -862,9 +877,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       partitionIds: util.List[Integer],
       oldEpochs: util.List[Integer],
       oldPartitions: util.List[PartitionLocation],
-      causes: util.List[StatusCode]): Unit = {
+      causes: util.List[StatusCode],
+      serdeVersion: SerdeVersion): Unit = {
     val contextWrapper =
-      ChangeLocationsCallContext(context, partitionIds.size())
+      ChangeLocationsCallContext(context, partitionIds.size(), serdeVersion)
     // If shuffle not registered, reply ShuffleNotRegistered and return
     if (!registeredShuffle.contains(shuffleId)) {
       logError(s"[handleRevive] shuffle $shuffleId not registered!")
@@ -916,7 +932,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       pushFailedBatches: util.Map[String, LocationPushFailedBatches],
       numPartitions: Int,
       crc32PerPartition: Array[Int],
-      bytesWrittenPerPartition: Array[Long]): Unit = {
+      bytesWrittenPerPartition: Array[Long],
+      serdeVersion: SerdeVersion): Unit = {
 
     val (mapperAttemptFinishedSuccess, allMapperFinished) =
       commitManager.finishMapperAttempt(
@@ -936,7 +953,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     }
 
     // reply success
-    context.reply(MapperEndResponse(StatusCode.SUCCESS))
+    context.reply(MapperEndResponse(StatusCode.SUCCESS, serdeVersion))
   }
 
   private def handleGetReducerFileGroup(
@@ -1205,7 +1222,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       mapId: Int,
       attemptId: Int,
       partitionId: Int,
-      numMappers: Int): Unit = {
+      numMappers: Int,
+      serdeVersion: SerdeVersion): Unit = {
     def reply(result: Boolean): Unit = {
       val message =
         s"to handle MapPartitionEnd for ${Utils.makeMapKey(shuffleId, mapId, 
attemptId)}, " +
@@ -1213,10 +1231,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       result match {
         case true => // if already committed by another try
           logDebug(s"Succeed $message")
-          context.reply(MapperEndResponse(StatusCode.SUCCESS))
+          context.reply(MapperEndResponse(StatusCode.SUCCESS, serdeVersion))
         case false =>
           logError(s"Failed $message, reply ${StatusCode.SHUFFLE_DATA_LOST}.")
-          context.reply(MapperEndResponse(StatusCode.SHUFFLE_DATA_LOST))
+          context.reply(MapperEndResponse(StatusCode.SHUFFLE_DATA_LOST, serdeVersion))
       }
     }
 
diff --git a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
index 091960a4c..9de71dd46 100644
--- a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
@@ -20,6 +20,7 @@ package org.apache.celeborn.client
 import java.util
 
 import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.network.protocol.SerdeVersion
 import org.apache.celeborn.common.protocol.PartitionLocation
 import org.apache.celeborn.common.protocol.message.ControlMessages.{ChangeLocationResponse, RegisterShuffleResponse}
 import org.apache.celeborn.common.protocol.message.StatusCode
@@ -36,11 +37,12 @@ trait RequestLocationCallContext {
 
 case class ChangeLocationsCallContext(
     context: RpcCallContext,
-    partitionCount: Int)
+    partitionCount: Int,
+    serdeVersion: SerdeVersion)
   extends RequestLocationCallContext with Logging {
-  val endedMapIds = new util.HashSet[Integer]()
+  val endedMapIds = new util.ArrayList[Integer]()
   val newLocs =
-    JavaUtils.newConcurrentHashMap[Integer, (StatusCode, Boolean, PartitionLocation)](
+    JavaUtils.newConcurrentHashMap[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)](
       partitionCount)
 
   def markMapperEnd(mapId: Int): Unit = this.synchronized {
@@ -59,12 +61,13 @@ case class ChangeLocationsCallContext(
 
     if (newLocs.size() == partitionCount || StatusCode.SHUFFLE_UNREGISTERED == status
       || StatusCode.STAGE_ENDED == status) {
-      context.reply(ChangeLocationResponse(endedMapIds, newLocs))
+      context.reply(ChangeLocationResponse(endedMapIds, newLocs, serdeVersion))
     }
   }
 }
 
-case class ApplyNewLocationCallContext(context: RpcCallContext) extends RequestLocationCallContext {
+case class ApplyNewLocationCallContext(context: RpcCallContext, serdeVersion: SerdeVersion)
+  extends RequestLocationCallContext {
   override def reply(
       partitionId: Int,
       status: StatusCode,
@@ -72,8 +75,8 @@ case class ApplyNewLocationCallContext(context: RpcCallContext) extends RequestL
       available: Boolean): Unit = {
     partitionLocationOpt match {
       case Some(partitionLocation) =>
-        context.reply(RegisterShuffleResponse(status, Array(partitionLocation)))
-      case None => context.reply(RegisterShuffleResponse(status, Array.empty))
+        context.reply(RegisterShuffleResponse(status, Array(partitionLocation), serdeVersion))
+      case None => context.reply(RegisterShuffleResponse(status, Array.empty, serdeVersion))
     }
   }
 }
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
index 7a7706973..2cc2cd1fd 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
@@ -30,9 +30,9 @@ import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
 import org.apache.celeborn.common.protocol.CompressionCodec;
 import org.apache.celeborn.common.protocol.PartitionLocation;
-import org.apache.celeborn.common.protocol.PbRegisterShuffleResponse;
 import org.apache.celeborn.common.protocol.message.ControlMessages;
 import org.apache.celeborn.common.protocol.message.StatusCode;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
@@ -90,12 +90,14 @@ public abstract class ShuffleClientBaseSuiteJ {
     primaryLocation.setPeer(replicaLocation);
 
     when(endpointRef.askSync(
-            ControlMessages.RegisterShuffle$.MODULE$.apply(TEST_SHUFFLE_ID, 1, 1),
-            ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)))
+            new ControlMessages.RegisterShuffle(TEST_SHUFFLE_ID, 1, 1, SerdeVersion.V1),
+            ClassTag$.MODULE$.apply(ControlMessages.RegisterShuffleResponse.class)))
         .thenAnswer(
             t ->
-                ControlMessages.RegisterShuffleResponse$.MODULE$.apply(
-                    StatusCode.SUCCESS, new PartitionLocation[] {primaryLocation}));
+                new ControlMessages.RegisterShuffleResponse(
+                    StatusCode.SUCCESS,
+                    new PartitionLocation[] {primaryLocation},
+                    SerdeVersion.V1));
 
     shuffleClient.setupLifecycleManagerRef(endpointRef);
     when(clientFactory.createClient(
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index 0f4b5c30f..e6d450d87 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -263,13 +263,13 @@ public class ShuffleClientSuiteJ {
         .thenAnswer(
             t ->
                 RegisterShuffleResponse$.MODULE$.apply(
-                    statusCode, new PartitionLocation[] {primaryLocation}));
+                    statusCode, new PartitionLocation[] {primaryLocation}, SerdeVersion.V1));
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
         .thenAnswer(
             t ->
                 RegisterShuffleResponse$.MODULE$.apply(
-                    statusCode, new PartitionLocation[] {primaryLocation}));
+                    statusCode, new PartitionLocation[] {primaryLocation}, SerdeVersion.V1));
 
     shuffleClient.setupLifecycleManagerRef(endpointRef);
 
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index eb9274632..36f164d69 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -131,17 +131,12 @@ object ControlMessages extends Logging {
       workerEvent: WorkerEventType = WorkerEventType.None)
     extends MasterMessage
 
-  object RegisterShuffle {
-    def apply(
-        shuffleId: Int,
-        numMappers: Int,
-        numPartitions: Int): PbRegisterShuffle =
-      PbRegisterShuffle.newBuilder()
-        .setShuffleId(shuffleId)
-        .setNumMappers(numMappers)
-        .setNumPartitions(numPartitions)
-        .build()
-  }
+  case class RegisterShuffle(
+      shuffleId: Int,
+      numMappers: Int,
+      numPartitions: Int,
+      serdeVersion: SerdeVersion)
+    extends MasterMessage
 
   object RegisterMapPartitionTask {
     def apply(
@@ -161,17 +156,10 @@ object ControlMessages extends Logging {
         .build()
   }
 
-  object RegisterShuffleResponse {
-    def apply(
-        status: StatusCode,
-        partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse = {
-      val builder = PbRegisterShuffleResponse.newBuilder()
-        .setStatus(status.getValue)
-      builder.setPackedPartitionLocationsPair(
-        PbSerDeUtils.toPbPackedPartitionLocationsPair(partitionLocations.toList))
-      builder.build()
-    }
-  }
+  case class RegisterShuffleResponse(
+      status: StatusCode,
+      partitionLocations: Array[PartitionLocation],
+      serdeVersion: SerdeVersion) extends MasterMessage
 
   case class RequestSlots(
       applicationId: String,
@@ -195,29 +183,11 @@ object ControlMessages extends Logging {
       packed: Boolean = false)
     extends MasterMessage
 
-  object Revive {
-    def apply(
-        shuffleId: Int,
-        mapIds: util.Set[Integer],
-        reviveRequests: util.Collection[ReviveRequest]): PbRevive = {
-      val builder = PbRevive.newBuilder()
-        .setShuffleId(shuffleId)
-        .addAllMapId(mapIds)
-
-      reviveRequests.asScala.foreach { req =>
-        val partitionInfoBuilder = PbRevivePartitionInfo.newBuilder()
-          .setPartitionId(req.partitionId)
-          .setEpoch(req.epoch)
-          .setStatus(req.cause.getValue)
-        if (req.loc != null) {
-          partitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc))
-        }
-        builder.addPartitionInfo(partitionInfoBuilder.build())
-      }
-
-      builder.build()
-    }
-  }
+  case class Revive(
+      shuffleId: Int,
+      mapIds: util.List[Integer],
+      reviveRequests: util.List[ReviveRequest],
+      serdeVersion: SerdeVersion) extends MasterMessage
 
   object PartitionSplit {
     def apply(
@@ -233,26 +203,10 @@ object ControlMessages extends Logging {
         .build()
   }
 
-  object ChangeLocationResponse {
-    def apply(
-        mapIds: util.Set[Integer],
-        newLocs: util.Map[Integer, (StatusCode, Boolean, PartitionLocation)])
-        : PbChangeLocationResponse = {
-      val builder = PbChangeLocationResponse.newBuilder()
-      builder.addAllEndedMapId(mapIds)
-      newLocs.asScala.foreach { case (partitionId, (status, available, loc)) =>
-        val pbChangeLocationPartitionInfoBuilder = PbChangeLocationPartitionInfo.newBuilder()
-          .setPartitionId(partitionId)
-          .setStatus(status.getValue)
-          .setOldAvailable(available)
-        if (loc != null) {
-          pbChangeLocationPartitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(loc))
-        }
-        builder.addPartitionInfo(pbChangeLocationPartitionInfoBuilder.build())
-      }
-      builder.build()
-    }
-  }
+  case class ChangeLocationResponse(
+      endedMapIds: util.List[Integer],
+      newLocs: util.Map[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)],
+      serdeVersion: SerdeVersion) extends MasterMessage
 
   case class MapperEnd(
       shuffleId: Int,
@@ -263,7 +217,8 @@ object ControlMessages extends Logging {
       failedBatchSet: util.Map[String, LocationPushFailedBatches],
       numPartitions: Int,
       crc32PerPartition: Array[Int],
-      bytesWrittenPerPartition: Array[Long])
+      bytesWrittenPerPartition: Array[Long],
+      serdeVersion: SerdeVersion)
     extends MasterMessage
 
   case class ReadReducerPartitionEnd(
@@ -275,7 +230,7 @@ object ControlMessages extends Logging {
       bytesWritten: Long)
     extends MasterMessage
 
-  case class MapperEndResponse(status: StatusCode) extends MasterMessage
+  case class MapperEndResponse(status: StatusCode, serdeVersion: SerdeVersion) extends MasterMessage
 
   case class ReadReducerPartitionEndResponse(status: StatusCode) extends MasterMessage
 
@@ -674,14 +629,23 @@ object ControlMessages extends Logging {
         .build().toByteArray
       new TransportMessage(MessageType.HEARTBEAT_FROM_WORKER_RESPONSE, payload)
 
-    case pb: PbRegisterShuffle =>
-      new TransportMessage(MessageType.REGISTER_SHUFFLE, pb.toByteArray)
+    case RegisterShuffle(shuffleId, numMappers, numPartitions, serdeVersion) =>
+      val payload = PbRegisterShuffle.newBuilder()
+        .setShuffleId(shuffleId)
+        .setNumMappers(numMappers)
+        .setNumPartitions(numPartitions)
+        .build().toByteArray
+      new TransportMessage(MessageType.REGISTER_SHUFFLE, payload, serdeVersion)
 
     case pb: PbRegisterMapPartitionTask =>
       new TransportMessage(MessageType.REGISTER_MAP_PARTITION_TASK, pb.toByteArray)
 
-    case pb: PbRegisterShuffleResponse =>
-      new TransportMessage(MessageType.REGISTER_SHUFFLE_RESPONSE, pb.toByteArray)
+    case RegisterShuffleResponse(status, partitionLocations, serdeVersion) =>
+      val payload = PbRegisterShuffleResponse.newBuilder()
+        .setStatus(status.getValue).setPackedPartitionLocationsPair(
+          PbSerDeUtils.toPbPackedPartitionLocationsPair(partitionLocations.toList))
+        .build().toByteArray
+      new TransportMessage(MessageType.REGISTER_SHUFFLE_RESPONSE, payload, serdeVersion)
 
     case RequestSlots(
           applicationId,
@@ -729,11 +693,39 @@ object ControlMessages extends Logging {
       val payload = builder.build().toByteArray
       new TransportMessage(MessageType.REQUEST_SLOTS_RESPONSE, payload)
 
-    case pb: PbRevive =>
-      new TransportMessage(MessageType.CHANGE_LOCATION, pb.toByteArray)
+    case Revive(shuffleId, mapIds, reviveRequests, serdeVersion) =>
+      val builder = PbRevive.newBuilder()
+        .setShuffleId(shuffleId)
+        .addAllMapId(mapIds)
 
-    case pb: PbChangeLocationResponse =>
-      new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray)
+      reviveRequests.asScala.foreach { req =>
+        val partitionInfoBuilder = PbRevivePartitionInfo.newBuilder()
+          .setPartitionId(req.partitionId)
+          .setEpoch(req.epoch)
+          .setStatus(req.cause.getValue)
+        if (req.loc != null) {
+          partitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc))
+        }
+        builder.addPartitionInfo(partitionInfoBuilder.build())
+      }
+      val payload = builder.build().toByteArray
+      new TransportMessage(MessageType.CHANGE_LOCATION, payload, serdeVersion)
+
+    case ChangeLocationResponse(mapIds, newLocs, serdeVersion) =>
+      val builder = PbChangeLocationResponse.newBuilder()
+      builder.addAllEndedMapId(mapIds)
+      newLocs.asScala.foreach { case (partitionId, (status, available, loc)) =>
+        val pbChangeLocationPartitionInfoBuilder = PbChangeLocationPartitionInfo.newBuilder()
+          .setPartitionId(partitionId)
+          .setStatus(status.getValue)
+          .setOldAvailable(available)
+        if (loc != null) {
+          pbChangeLocationPartitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(loc))
+        }
+        builder.addPartitionInfo(pbChangeLocationPartitionInfoBuilder.build())
+      }
+      val payload = builder.build().toByteArray
+      new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, payload, serdeVersion)
 
     case MapperEnd(
           shuffleId,
@@ -744,7 +736,8 @@ object ControlMessages extends Logging {
           pushFailedBatch,
           numPartitions,
           crc32PerPartition,
-          bytesWrittenPerPartition) =>
+          bytesWrittenPerPartition,
+          serdeVersion) =>
       val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
         val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
         (k, resultValue)
@@ -761,13 +754,13 @@ object ControlMessages extends Logging {
         .addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map(
           java.lang.Long.valueOf).toSeq.asJava)
         .build().toByteArray
-      new TransportMessage(MessageType.MAPPER_END, payload)
+      new TransportMessage(MessageType.MAPPER_END, payload, serdeVersion)
 
-    case MapperEndResponse(status) =>
+    case MapperEndResponse(status, serdeVersion) =>
       val payload = PbMapperEndResponse.newBuilder()
         .setStatus(status.getValue)
         .build().toByteArray
-      new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload)
+      new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload, serdeVersion)
 
     case GetReducerFileGroup(shuffleId, isSegmentGranularityVisible, serdeVersion) =>
       val payload = PbGetReducerFileGroup.newBuilder()
@@ -1132,13 +1125,23 @@ object ControlMessages extends Logging {
           pbHeartbeatFromWorkerResponse.getWorkerEventType)
 
       case REGISTER_SHUFFLE_VALUE =>
-        PbRegisterShuffle.parseFrom(message.getPayload)
+        val pbRegisterShuffle = PbRegisterShuffle.parseFrom(message.getPayload)
+        RegisterShuffle(
+          pbRegisterShuffle.getShuffleId,
+          pbRegisterShuffle.getNumMappers,
+          pbRegisterShuffle.getNumPartitions,
+          message.getSerdeVersion)
 
       case REGISTER_MAP_PARTITION_TASK_VALUE =>
         PbRegisterMapPartitionTask.parseFrom(message.getPayload)
 
       case REGISTER_SHUFFLE_RESPONSE_VALUE =>
-        PbRegisterShuffleResponse.parseFrom(message.getPayload)
+        val pbRegisterShuffleResponse = PbRegisterShuffleResponse.parseFrom(message.getPayload)
+        RegisterShuffleResponse(
+          StatusCode.fromValue(pbRegisterShuffleResponse.getStatus),
+          PbSerDeUtils.fromPbPackedPartitionLocationsPair(
+            pbRegisterShuffleResponse.getPackedPartitionLocationsPair)._1.asScala.toArray,
+          message.getSerdeVersion)
 
       case REQUEST_SLOTS_VALUE =>
         val pbRequestSlots = PbRequestSlots.parseFrom(message.getPayload)
@@ -1175,10 +1178,51 @@ object ControlMessages extends Logging {
           workerResource)
 
       case CHANGE_LOCATION_VALUE =>
-        PbRevive.parseFrom(message.getPayload)
+        val pbRevive = PbRevive.parseFrom(message.getPayload)
+        val shuffleId = pbRevive.getShuffleId
+        val partitionInfos = pbRevive.getPartitionInfoList
+        val reviveRequests = new util.ArrayList[ReviveRequest]()
+        (0 until partitionInfos.size).foreach { idx =>
+          val info = partitionInfos.get(idx)
+          var partition: PartitionLocation = null
+          if (info.hasPartition) {
+            partition = PbSerDeUtils.fromPbPartitionLocation(info.getPartition)
+          }
+          val reviveRequest = new ReviveRequest(
+            shuffleId,
+            -1,
+            -1,
+            info.getPartitionId,
+            info.getEpoch,
+            partition,
+            StatusCode.fromValue(info.getStatus))
+          reviveRequests.add(reviveRequest)
+        }
+        Revive(
+          pbRevive.getShuffleId,
+          pbRevive.getMapIdList,
+          reviveRequests,
+          message.getSerdeVersion)
 
       case CHANGE_LOCATION_RESPONSE_VALUE =>
-        PbChangeLocationResponse.parseFrom(message.getPayload)
+        val pbChangeLocationResponse = PbChangeLocationResponse.parseFrom(message.getPayload)
+        val newLocs =
+          new util.HashMap[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)]()
+        val partitionInfos = pbChangeLocationResponse.getPartitionInfoList
+        (0 until partitionInfos.size).foreach { idx =>
+          val info = partitionInfos.get(idx)
+          var partition: PartitionLocation = null
+          if (info.hasPartition) {
+            partition = PbSerDeUtils.fromPbPartitionLocation(info.getPartition)
+          }
+          newLocs.put(
+            info.getPartitionId,
+            (StatusCode.fromValue(info.getStatus), info.getOldAvailable, partition))
+        }
+        ChangeLocationResponse(
+          pbChangeLocationResponse.getEndedMapIdList,
+          newLocs,
+          message.getSerdeVersion)
 
       case MAPPER_END_VALUE =>
         val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload)
@@ -1203,7 +1247,8 @@ object ControlMessages extends Logging {
           }.toMap.asJava,
           pbMapperEnd.getNumPartitions,
           crc32Array,
-          bytesWrittenPerPartitionArray)
+          bytesWrittenPerPartitionArray,
+          message.getSerdeVersion)
 
       case READ_REDUCER_PARTITION_END_VALUE =>
        val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload)
@@ -1220,7 +1265,9 @@ object ControlMessages extends Logging {
 
       case MAPPER_END_RESPONSE_VALUE =>
        val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload)
-        MapperEndResponse(StatusCode.fromValue(pbMapperEndResponse.getStatus))
+        MapperEndResponse(
+          StatusCode.fromValue(pbMapperEndResponse.getStatus),
+          message.getSerdeVersion)
 
       case GET_REDUCER_FILE_GROUP_VALUE =>
        val pbGetReducerFileGroup = PbGetReducerFileGroup.parseFrom(message.getPayload)
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
index 7cadaf07e..8be472b64 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
@@ -28,6 +28,7 @@ import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.client.{MasterEndpointResolver, StaticMasterEndpointResolver}
 import org.apache.celeborn.common.exception.CelebornException
 import org.apache.celeborn.common.identity.DefaultIdentityProvider
+import org.apache.celeborn.common.network.protocol.SerdeVersion
 import org.apache.celeborn.common.protocol.{PartitionLocation, TransportModuleConstants}
 import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, MapperEnd}
 import org.apache.celeborn.common.protocol.message.StatusCode
@@ -149,7 +150,17 @@ class UtilsSuite extends CelebornFunSuite {
 
   test("MapperEnd class convert with pb") {
     val mapperEnd =
-      MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap(), 1, Array.emptyIntArray, Array.emptyLongArray)
+      MapperEnd(
+        1,
+        1,
+        1,
+        2,
+        1,
+        Collections.emptyMap(),
+        1,
+        Array.emptyIntArray,
+        Array.emptyLongArray,
+        SerdeVersion.V1)
    val mapperEndTrans =
      Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
     assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId)
diff --git a/cpp/celeborn/tests/CMakeLists.txt b/cpp/celeborn/tests/CMakeLists.txt
index 0bc5e41c9..104607ce2 100644
--- a/cpp/celeborn/tests/CMakeLists.txt
+++ b/cpp/celeborn/tests/CMakeLists.txt
@@ -35,4 +35,29 @@ target_link_libraries(
 
 add_executable(cppDataSumWithReaderClient DataSumWithReaderClient.cpp)
 
-target_link_libraries(cppDataSumWithReaderClient dataSumWithReaderClient)
\ No newline at end of file
+target_link_libraries(cppDataSumWithReaderClient dataSumWithReaderClient)
+
+add_library(
+        dataSumWithWriterClient
+        DataSumWithWriterClient.cpp)
+
+target_link_libraries(
+        dataSumWithWriterClient
+        memory
+        utils
+        conf
+        proto
+        network
+        protocol
+        client
+        ${WANGLE}
+        ${FIZZ}
+        ${LIBSODIUM_LIBRARY}
+        ${FOLLY_WITH_DEPENDENCIES}
+        ${GLOG}
+        ${GFLAGS_LIBRARIES}
+)
+
+add_executable(cppDataSumWithWriterClient DataSumWithWriterClient.cpp)
+
+target_link_libraries(cppDataSumWithWriterClient dataSumWithWriterClient)
diff --git a/cpp/celeborn/tests/DataSumWithWriterClient.cpp b/cpp/celeborn/tests/DataSumWithWriterClient.cpp
new file mode 100644
index 000000000..ceaa01beb
--- /dev/null
+++ b/cpp/celeborn/tests/DataSumWithWriterClient.cpp
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/init/Init.h>
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+
+#include <celeborn/client/ShuffleClient.h>
+
+int main(int argc, char** argv) {
+  folly::init(&argc, &argv, false);
+  // Read the configs.
+  assert(argc == 9);
+  std::string lifecycleManagerHost = argv[1];
+  int lifecycleManagerPort = std::atoi(argv[2]);
+  std::string appUniqueId = argv[3];
+  int shuffleId = std::atoi(argv[4]);
+  int attemptId = std::atoi(argv[5]);
+  int numMappers = std::atoi(argv[6]);
+  int numPartitions = std::atoi(argv[7]);
+  std::string resultFile = argv[8];
+  std::cout << "lifecycleManagerHost = " << lifecycleManagerHost
+            << ", lifecycleManagerPort = " << lifecycleManagerPort
+            << ", appUniqueId = " << appUniqueId
+            << ", shuffleId = " << shuffleId << ", attemptId = " << attemptId
+            << ", numMappers = " << numMappers
+            << ", numPartitions = " << numPartitions
+            << ", resultFile = " << resultFile << std::endl;
+
+  // Create shuffleClient and setup.
+  auto conf = std::make_shared<celeborn::conf::CelebornConf>();
+  auto clientEndpoint =
+      std::make_shared<celeborn::client::ShuffleClientEndpoint>(conf);
+  auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
+      appUniqueId, conf, *clientEndpoint);
+  shuffleClient->setupLifecycleManagerRef(
+      lifecycleManagerHost, lifecycleManagerPort);
+
+  long maxData = 1000000;
+  size_t numData = 1000;
+  // Generate data, sum up and pushData.
+  std::vector<long> result(numPartitions, 0);
+  std::vector<size_t> dataCnt(numPartitions, 0);
+  for (int mapId = 0; mapId < numMappers; mapId++) {
+    for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+      std::string partitionData;
+      for (size_t i = 0; i < numData; i++) {
+        int data = std::rand() % maxData;
+        result[partitionId] += data;
+        dataCnt[partitionId]++;
+        partitionData += "-" + std::to_string(data);
+      }
+      shuffleClient->pushData(
+          shuffleId,
+          mapId,
+          attemptId,
+          partitionId,
+          reinterpret_cast<const uint8_t*>(partitionData.c_str()),
+          0,
+          partitionData.size(),
+          numMappers,
+          numPartitions);
+    }
+    shuffleClient->mapperEnd(shuffleId, mapId, attemptId, numMappers);
+  }
+  for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+    std::cout << "partition " << partitionId
+              << " sum result = " << result[partitionId]
+              << ", dataCnt = " << dataCnt[partitionId] << std::endl;
+  }
+
+  // Write result to resultFile.
+  remove(resultFile.c_str());
+  std::ofstream of(resultFile);
+  for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+    of << result[partitionId] << std::endl;
+  }
+  of.close();
+
+  return 0;
+}
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala
similarity index 88%
copy from worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
copy to worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala
index bc1961384..b7fb62a4a 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala
@@ -19,9 +19,9 @@ package org.apache.celeborn.service.deploy.cluster
 
 import org.apache.celeborn.common.protocol.CompressionCodec
 
-object JavaWriteCppReadTestWithLZ4 extends JavaWriteCppReadTestBase {
+object CppWriteJavaReadTestWithNONE extends JavaCppHybridReadWriteTestBase {
 
   def main(args: Array[String]) = {
-    testJavaWriteCppRead(CompressionCodec.LZ4)
+    testCppWriteJavaRead(CompressionCodec.NONE)
   }
 }
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
similarity index 60%
rename from worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala
rename to worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
index e059754e3..325f9c8b7 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
@@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.client.read.MetricsCallback
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
@@ -35,7 +36,7 @@ import org.apache.celeborn.common.protocol.CompressionCodec
 import org.apache.celeborn.common.util.Utils.runCommand
 import org.apache.celeborn.service.deploy.MiniClusterFeature
 
-trait JavaWriteCppReadTestBase extends AnyFunSuite
+trait JavaCppHybridReadWriteTestBase extends AnyFunSuite
   with Logging with MiniClusterFeature with BeforeAndAfterAll {
 
   var masterPort = 0
@@ -147,4 +148,99 @@ trait JavaWriteCppReadTestBase extends AnyFunSuite
     shuffleClient.shutdown()
   }
 
+  def testCppWriteJavaRead(codec: CompressionCodec): Unit = {
+    beforeAll()
+    try {
+      runCppWriteJavaRead(codec)
+    } finally {
+      afterAll()
+    }
+  }
+
+  def runCppWriteJavaRead(codec: CompressionCodec): Unit = {
+    val appUniqueId = "test-app"
+    val shuffleId = 0
+    val attemptId = 0
+
+    // Create lifecycleManager.
+    val clientConf = new CelebornConf()
+      .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort")
+      .set(CelebornConf.SHUFFLE_COMPRESSION_CODEC.key, codec.name)
+      .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true")
+      .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
+      .set(CelebornConf.READ_LOCAL_SHUFFLE_FILE, false)
+      .set("celeborn.data.io.numConnectionsPerPeer", "1")
+    val lifecycleManager = new LifecycleManager(appUniqueId, clientConf)
+
+    // Create writer shuffleClient.
+    val shuffleClient =
+      new ShuffleClientImpl(appUniqueId, clientConf, UserIdentifier("mock", "mock"))
+    shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+    val numMappers = 2
+    val numPartitions = 2
+
+    // Launch cpp writer to write data, calculate result and write to specific result file.
+    val cppResultFile = "/tmp/celeborn-cpp-writer-result.txt"
+    val lifecycleManagerHost = lifecycleManager.getHost
+    val lifecycleManagerPort = lifecycleManager.getPort
+    val projectDirectory = new File(new File(".").getAbsolutePath)
+    val cppBinRelativeDirectory = "cpp/build/celeborn/tests/"
+    val cppBinFileName = "cppDataSumWithWriterClient"
+    val cppBinFilePath = s"$projectDirectory/$cppBinRelativeDirectory/$cppBinFileName"
+    // Execution command: $exec lifecycleManagerHost lifecycleManagerPort appUniqueId shuffleId attemptId numMappers numPartitions cppResultFile
+    val command = {
+      s"$cppBinFilePath $lifecycleManagerHost $lifecycleManagerPort 
$appUniqueId $shuffleId $attemptId $numMappers $numPartitions $cppResultFile"
+    }
+    println(s"run command: $command")
+    val commandOutput = runCommand(command)
+    println(s"command output: $commandOutput")
+
+    val metricsCallback = new MetricsCallback {
+      override def incBytesRead(bytesWritten: Long): Unit = {}
+      override def incReadTime(time: Long): Unit = {}
+    }
+
+    var sums = new util.ArrayList[Long](numPartitions)
+    for (partitionId <- 0 until numPartitions) {
+      sums.add(0)
+      val inputStream = shuffleClient.readPartition(
+        shuffleId,
+        partitionId,
+        attemptId,
+        0,
+        0,
+        Integer.MAX_VALUE,
+        metricsCallback)
+      var c = inputStream.read()
+      var data: Long = 0
+      var dataCnt = 0
+      while (c != -1) {
+        if (c == '-') {
+          sums.set(partitionId, sums.get(partitionId) + data)
+          data = 0
+          dataCnt += 1
+        } else {
+          assert(c >= '0' && c <= '9')
+          data *= 10
+          data += c - '0'
+        }
+        c = inputStream.read()
+      }
+      sums.set(partitionId, sums.get(partitionId) + data)
+      println(s"partition $partitionId sum result = ${sums.get(partitionId)}, 
dataCnt = $dataCnt")
+    }
+
+    // Verify the sum result.
+    var lineCount = 0
+    for (line <- Source.fromFile(cppResultFile, "utf-8").getLines.toList) {
+      val data = line.toLong
+      Assert.assertEquals(data, sums.get(lineCount))
+      lineCount += 1
+    }
+    Assert.assertEquals(lineCount, numPartitions)
+    lifecycleManager.stop()
+    shuffleClient.shutdown()
+  }
+
 }
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
index bc1961384..327754ed9 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
 
 import org.apache.celeborn.common.protocol.CompressionCodec
 
-object JavaWriteCppReadTestWithLZ4 extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithLZ4 extends JavaCppHybridReadWriteTestBase {
 
   def main(args: Array[String]) = {
     testJavaWriteCppRead(CompressionCodec.LZ4)
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
index a649f8350..18bb8a418 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
 
 import org.apache.celeborn.common.protocol.CompressionCodec
 
-object JavaWriteCppReadTestWithNONE extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithNONE extends JavaCppHybridReadWriteTestBase {
 
   def main(args: Array[String]) = {
     testJavaWriteCppRead(CompressionCodec.NONE)
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
index f2ba2e769..de7cdf102 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
 
 import org.apache.celeborn.common.protocol.CompressionCodec
 
-object JavaWriteCppReadTestWithZSTD extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithZSTD extends JavaCppHybridReadWriteTestBase {
 
   def main(args: Array[String]) = {
     testJavaWriteCppRead(CompressionCodec.ZSTD)
