This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new e41ee2dc9 [CELEBORN-1721][CIP-12] Support HARD_SPLIT in PushMergedData
e41ee2dc9 is described below

commit e41ee2dc9b643f91ae8dca3512d6a5f2305a2a80
Author: jiang13021 <[email protected]>
AuthorDate: Fri Dec 6 09:20:36 2024 +0800

    [CELEBORN-1721][CIP-12] Support HARD_SPLIT in PushMergedData
    
    ### What changes were proposed in this pull request?
    As title.
    
    ### Why are the changes needed?
    
https://docs.google.com/document/d/1Jaix22vME0m1Q-JtTHF9WYsrsxBWBwzwmifcPxNQZHk/edit?tab=t.0#heading=h.iadpu3t4rywi
    (Thanks to cfmcgrady littlexyw ErikFang waitinfuture RexXiong FMX for
their efforts on the proposal)
    
    ### Does this PR introduce _any_ user-facing change?
    The response of pushMergedData has been modified; however, the changes are
backward compatible.
    
    ### How was this patch tested?
    UT: org.apache.celeborn.service.deploy.cluster.PushMergedDataHardSplitSuite
    
    Closes #2924 from jiang13021/cip-12.
    
    Authored-by: jiang13021 <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 133 +++--
 .../common/network/protocol/TransportMessage.java  |   2 +
 .../common/protocol/message/StatusCode.java        |  18 +-
 common/src/main/proto/TransportMessages.proto      |   8 +-
 .../common/protocol/message/ControlMessages.scala  |   6 +
 .../service/deploy/worker/PushDataHandler.scala    | 533 ++++++++++++++-------
 .../deploy/cluster/PushMergedDataSplitSuite.scala  | 162 +++++++
 7 files changed, 655 insertions(+), 207 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 224ac57aa..f1e6a57e3 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -30,6 +30,7 @@ import scala.reflect.ClassTag$;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
+import com.google.protobuf.InvalidProtocolBufferException;
 import io.netty.buffer.CompositeByteBuf;
 import io.netty.buffer.Unpooled;
 import org.apache.commons.lang3.StringUtils;
@@ -51,6 +52,7 @@ import 
org.apache.celeborn.common.network.client.TransportClientBootstrap;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.network.protocol.PushData;
 import org.apache.celeborn.common.network.protocol.PushMergedData;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.sasl.SaslClientBootstrap;
 import org.apache.celeborn.common.network.sasl.SaslCredentials;
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
@@ -1438,9 +1440,64 @@ public class ShuffleClientImpl extends ShuffleClient {
         new RpcResponseCallback() {
           @Override
           public void onSuccess(ByteBuffer response) {
-            if (response.remaining() > 0) {
-              byte reason = response.get();
-              if (reason == StatusCode.HARD_SPLIT.getValue()) {
+            byte reason = response.get();
+            if (reason == StatusCode.HARD_SPLIT.getValue()) {
+              ArrayList<DataBatches.DataBatch> batchesNeedResubmit;
+              if (response.remaining() > 0) {
+                batchesNeedResubmit = new ArrayList<>();
+                PbPushMergedDataSplitPartitionInfo partitionInfo;
+                try {
+                  partitionInfo = 
TransportMessage.fromByteBuffer(response).getParsedPayload();
+                } catch (CelebornIOException | InvalidProtocolBufferException 
e) {
+                  callback.onFailure(
+                      new CelebornIOException("parse pushMergedData response 
failed", e));
+                  return;
+                }
+                List<Integer> splitPartitionIndexes = 
partitionInfo.getSplitPartitionIndexesList();
+                List<Integer> statusCodeList = 
partitionInfo.getStatusCodesList();
+                StringBuilder dataBatchReviveInfos = new StringBuilder();
+                for (int i = 0; i < splitPartitionIndexes.size(); i++) {
+                  int partitionIndex = splitPartitionIndexes.get(i);
+                  int batchId = batches.get(partitionIndex).batchId;
+                  dataBatchReviveInfos.append(
+                      String.format(
+                          "(batchId=%d, partitionId=%d, cause=%s)",
+                          batchId,
+                          partitionIds[partitionIndex],
+                          
StatusCode.fromValue(statusCodeList.get(i).byteValue())));
+                  if (statusCodeList.get(i) == 
StatusCode.SOFT_SPLIT.getValue()) {
+                    PartitionLocation loc = batches.get(i).loc;
+                    if (!newerPartitionLocationExists(
+                        reducePartitionMap.get(shuffleId), loc.getId(), 
loc.getEpoch(), false)) {
+                      ReviveRequest reviveRequest =
+                          new ReviveRequest(
+                              shuffleId,
+                              mapId,
+                              attemptId,
+                              loc.getId(),
+                              loc.getEpoch(),
+                              loc,
+                              StatusCode.SOFT_SPLIT);
+                      reviveManager.addRequest(reviveRequest);
+                    }
+                  } else {
+                    batchesNeedResubmit.add(batches.get(partitionIndex));
+                  }
+                }
+                logger.info(
+                    "Push merged data to {} partial success required for 
shuffle {} map {} attempt {} groupedBatch {}. split batches {}.",
+                    addressPair,
+                    shuffleId,
+                    mapId,
+                    attemptId,
+                    groupedBatchId,
+                    dataBatchReviveInfos);
+              } else {
+                // Workers that do not incorporate changes from [CELEBORN-1721]
+                // will respond with a status of HARD_SPLIT,
+                // but will not include a PbPushMergedDataSplitPartitionInfo.
+                // For backward compatibility, all batches must be resubmitted.
+                batchesNeedResubmit = batches;
                 logger.info(
                     "Push merged data to {} hard split required for shuffle {} 
map {} attempt {} partition {} groupedBatch {} batch {}.",
                     addressPair,
@@ -1450,10 +1507,14 @@ public class ShuffleClientImpl extends ShuffleClient {
                     Arrays.toString(partitionIds),
                     groupedBatchId,
                     Arrays.toString(batchIds));
-
+              }
+              if (batchesNeedResubmit.isEmpty()) {
+                pushState.onSuccess(hostPort);
+                callback.onSuccess(ByteBuffer.wrap(new byte[] 
{StatusCode.SOFT_SPLIT.getValue()}));
+              } else {
                 ReviveRequest[] requests =
                     addAndGetReviveRequests(
-                        shuffleId, mapId, attemptId, batches, 
StatusCode.HARD_SPLIT);
+                        shuffleId, mapId, attemptId, batchesNeedResubmit, 
StatusCode.HARD_SPLIT);
                 pushDataRetryPool.submit(
                     () ->
                         submitRetryPushMergedData(
@@ -1461,7 +1522,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                             shuffleId,
                             mapId,
                             attemptId,
-                            batches,
+                            batchesNeedResubmit,
                             StatusCode.HARD_SPLIT,
                             groupedBatchId,
                             requests,
@@ -1470,39 +1531,37 @@ public class ShuffleClientImpl extends ShuffleClient {
                                 + 
conf.clientRpcRequestPartitionLocationAskTimeout()
                                     .duration()
                                     .toMillis()));
-              } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
-                logger.debug(
-                    "Push merged data to {} primary congestion required for 
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
-                    addressPair,
-                    shuffleId,
-                    mapId,
-                    attemptId,
-                    Arrays.toString(partitionIds),
-                    groupedBatchId,
-                    Arrays.toString(batchIds));
-                pushState.onCongestControl(hostPort);
-                callback.onSuccess(response);
-              } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
-                logger.debug(
-                    "Push merged data to {} replica congestion required for 
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
-                    addressPair,
-                    shuffleId,
-                    mapId,
-                    attemptId,
-                    Arrays.toString(partitionIds),
-                    groupedBatchId,
-                    Arrays.toString(batchIds));
-                pushState.onCongestControl(hostPort);
-                callback.onSuccess(response);
-              } else {
-                // StageEnd.
-                response.rewind();
-                pushState.onSuccess(hostPort);
-                callback.onSuccess(response);
               }
-            } else {
-              pushState.onSuccess(hostPort);
+            } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
+              logger.debug(
+                  "Push merged data to {} primary congestion required for 
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
+                  addressPair,
+                  shuffleId,
+                  mapId,
+                  attemptId,
+                  Arrays.toString(partitionIds),
+                  groupedBatchId,
+                  Arrays.toString(batchIds));
+              pushState.onCongestControl(hostPort);
               callback.onSuccess(response);
+            } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
+              logger.debug(
+                  "Push merged data to {} replica congestion required for 
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
+                  addressPair,
+                  shuffleId,
+                  mapId,
+                  attemptId,
+                  Arrays.toString(partitionIds),
+                  groupedBatchId,
+                  Arrays.toString(batchIds));
+              pushState.onCongestControl(hostPort);
+              callback.onSuccess(response);
+            } else if (reason == StatusCode.MAP_ENDED.getValue()) {
+              pushState.onSuccess(hostPort);
+              callback.onSuccess(ByteBuffer.wrap(new byte[] 
{StatusCode.MAP_ENDED.getValue()}));
+            } else { // success
+              pushState.onSuccess(hostPort);
+              callback.onSuccess(ByteBuffer.wrap(new byte[] 
{StatusCode.SUCCESS.getValue()}));
             }
           }
 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 239874e8c..01a9a37f9 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -113,6 +113,8 @@ public class TransportMessage implements Serializable {
         return (T) PbSegmentStart.parseFrom(payload);
       case NOTIFY_REQUIRED_SEGMENT_VALUE:
         return (T) PbNotifyRequiredSegment.parseFrom(payload);
+      case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE:
+        return (T) PbPushMergedDataSplitPartitionInfo.parseFrom(payload);
       default:
         logger.error("Unexpected type {}", type);
     }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index ca8655bab..086368ca1 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -17,6 +17,10 @@
 
 package org.apache.celeborn.common.protocol.message;
 
+import java.util.Arrays;
+import java.util.Map;
+import java.util.stream.Collectors;
+
 public enum StatusCode {
   // 1/0 Status
   SUCCESS(0),
@@ -84,7 +88,8 @@ public enum StatusCode {
   PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA(50),
   OPEN_STREAM_FAILED(51),
   SEGMENT_START_FAIL_REPLICA(52),
-  SEGMENT_START_FAIL_PRIMARY(53);
+  SEGMENT_START_FAIL_PRIMARY(53),
+  NO_SPLIT(54);
 
   private final byte value;
 
@@ -96,4 +101,15 @@ public enum StatusCode {
   public final byte getValue() {
     return value;
   }
+
+  private static final Map<Byte, StatusCode> lookup =
+      Arrays.stream(StatusCode.values()).collect(Collectors.toMap(i -> 
i.getValue(), i -> i));
+
+  public static StatusCode fromValue(byte value) {
+    StatusCode code = lookup.get(value);
+    if (code != null) {
+      return code;
+    }
+    throw new IllegalArgumentException("Unknown status code: " + value);
+  }
 }
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 5c03ba291..d7b439a68 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -111,6 +111,7 @@ enum MessageType {
   BATCH_UNREGISTER_SHUFFLE_RESPONSE= 88;
   REVISE_LOST_SHUFFLES = 89;
   REVISE_LOST_SHUFFLES_RESPONSE = 90;
+  PUSH_MERGED_DATA_SPLIT_PARTITION_INFO = 91;
 }
 
 enum StreamType {
@@ -875,4 +876,9 @@ message PbReviseLostShuffles{
 message PbReviseLostShufflesResponse{
   bool success = 1;
   string message = 2;
-}
\ No newline at end of file
+}
+
+message PbPushMergedDataSplitPartitionInfo {
+  repeated int32 splitPartitionIndexes = 1;
+  repeated int32 statusCodes = 2;
+}
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index bd1a6a11e..3c9b22c54 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -602,6 +602,9 @@ object ControlMessages extends Logging {
         MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE,
         pb.toByteArray)
 
+    case pb: PbPushMergedDataSplitPartitionInfo =>
+      new TransportMessage(MessageType.PUSH_MERGED_DATA_SPLIT_PARTITION_INFO, 
pb.toByteArray)
+
     case HeartbeatFromWorker(
           host,
           rpcPort,
@@ -1400,6 +1403,9 @@ object ControlMessages extends Logging {
 
       case REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE_VALUE =>
         
PbReportBarrierStageAttemptFailureResponse.parseFrom(message.getPayload)
+
+      case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE =>
+        PbPushMergedDataSplitPartitionInfo.parseFrom(message.getPayload)
     }
   }
 }
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index af6d4cc66..0890f8355 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -18,9 +18,11 @@
 package org.apache.celeborn.service.deploy.worker
 
 import java.nio.ByteBuffer
+import java.util
 import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray}
 
+import scala.collection.mutable
 import scala.concurrent.{Await, Promise}
 import scala.concurrent.duration.Duration
 import scala.util.{Failure, Success, Try}
@@ -37,7 +39,7 @@ import 
org.apache.celeborn.common.network.client.{RpcResponseCallback, Transport
 import org.apache.celeborn.common.network.protocol.{Message, PushData, 
PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage, 
RpcFailure, RpcRequest, RpcResponse, TransportMessage}
 import org.apache.celeborn.common.network.protocol.Message.Type
 import org.apache.celeborn.common.network.server.BaseMessageHandler
-import org.apache.celeborn.common.protocol.{PartitionLocation, 
PartitionSplitMode, PartitionType, PbPushDataHandShake, PbRegionFinish, 
PbRegionStart, PbSegmentStart}
+import org.apache.celeborn.common.protocol.{PartitionLocation, 
PartitionSplitMode, PartitionType, PbPushDataHandShake, 
PbPushMergedDataSplitPartitionInfo, PbRegionFinish, PbRegionStart, 
PbSegmentStart}
 import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.unsafe.Platform
@@ -183,7 +185,7 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
 
     // Fetch real batchId from body will add more cost and no meaning for 
replicate.
     val doReplicate = location != null && location.hasPeer && isPrimary
-    val softSplit = new AtomicBoolean(false)
+    var softSplit = false
 
     if (location == null) {
       val (mapId, attemptId) = getMapAttempt(body)
@@ -248,7 +250,14 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       return
     }
 
-    if (checkDiskFullAndSplit(fileWriter, isPrimary, softSplit, 
callbackWithTimer)) return
+    val splitStatus = checkDiskFullAndSplit(fileWriter, isPrimary)
+    if (splitStatus == StatusCode.HARD_SPLIT) {
+      workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+      
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+      return
+    } else if (splitStatus == StatusCode.SOFT_SPLIT) {
+      softSplit = true
+    }
 
     fileWriter.incrementPendingWrites()
 
@@ -261,7 +270,7 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       fileWriter.decrementPendingWrites()
       return
     }
-    val writePromise = Promise[Unit]()
+    val writePromise = Promise[Array[StatusCode]]()
     // for primary, send data to replica
     if (doReplicate) {
       pushData.body().retain()
@@ -288,34 +297,39 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
           val wrappedCallback = new RpcResponseCallback() {
             override def onSuccess(response: ByteBuffer): Unit = {
               Try(Await.result(writePromise.future, Duration.Inf)) match {
-                case Success(_) =>
-                  if (response.remaining() > 0) {
-                    val resp = ByteBuffer.allocate(response.remaining())
-                    resp.put(response)
-                    resp.flip()
-                    callbackWithTimer.onSuccess(resp)
-                  } else if (softSplit.get()) {
-                    // TODO Currently if the worker is in soft split status, 
given the guess that the client
-                    // will fast stop pushing data to the worker, we won't 
return congest status. But
-                    // in the long term, especially if this issue could 
frequently happen, we may need to return
-                    // congest&softSplit status together
-                    callbackWithTimer.onSuccess(
-                      
ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+                case Success(result) =>
+                  if (result(0) != StatusCode.SUCCESS) {
+                    
callback.onSuccess(ByteBuffer.wrap(Array[Byte](result(0).getValue)))
                   } else {
-                    Option(CongestionController.instance()) match {
-                      case Some(congestionController) =>
-                        if (congestionController.isUserCongested(
-                            fileWriter.getUserCongestionControlContext)) {
-                          // Check whether primary congest the data though the 
replicas doesn't congest
-                          // it(the response is empty)
-                          callbackWithTimer.onSuccess(
-                            ByteBuffer.wrap(
-                              
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
-                        } else {
+                    if (response.remaining() > 0) {
+                      val resp = ByteBuffer.allocate(response.remaining())
+                      resp.put(response)
+                      resp.flip()
+                      callbackWithTimer.onSuccess(resp)
+                    } else if (softSplit) {
+                      // TODO Currently if the worker is in soft split status, 
given the guess that the client
+                      // will fast stop pushing data to the worker, we won't 
return congest status. But
+                      // in the long term, especially if this issue could 
frequently happen, we may need to return
+                      // congest&softSplit status together
+                      callbackWithTimer.onSuccess(
+                        
ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+                    } else {
+                      Option(CongestionController.instance()) match {
+                        case Some(congestionController) =>
+                          if (congestionController.isUserCongested(
+                              fileWriter.getUserCongestionControlContext)) {
+                            // Check whether primary congest the data though 
the replicas doesn't congest
+                            // it(the response is empty)
+                            callbackWithTimer.onSuccess(
+                              ByteBuffer.wrap(
+                                Array[Byte](
+                                  
StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+                          } else {
+                            
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+                          }
+                        case None =>
                           
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
-                        }
-                      case None =>
-                        
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+                      }
                     }
                   }
                 case Failure(e) => callbackWithTimer.onFailure(e)
@@ -376,29 +390,33 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       // congest&softSplit status together
       writeLocalData(Seq(fileWriter), body, shuffleKey, isPrimary, None, 
writePromise)
       Try(Await.result(writePromise.future, Duration.Inf)) match {
-        case Success(_) =>
-          if (softSplit.get()) {
-            callbackWithTimer.onSuccess(
-              ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+        case Success(result) =>
+          if (result(0) != StatusCode.SUCCESS) {
+            
callback.onSuccess(ByteBuffer.wrap(Array[Byte](result(0).getValue)))
           } else {
-            Option(CongestionController.instance()) match {
-              case Some(congestionController) =>
-                if (congestionController.isUserCongested(
-                    fileWriter.getUserCongestionControlContext)) {
-                  if (isPrimary) {
-                    callbackWithTimer.onSuccess(
-                      ByteBuffer.wrap(
-                        
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+            if (softSplit) {
+              callbackWithTimer.onSuccess(
+                ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+            } else {
+              Option(CongestionController.instance()) match {
+                case Some(congestionController) =>
+                  if (congestionController.isUserCongested(
+                      fileWriter.getUserCongestionControlContext)) {
+                    if (isPrimary) {
+                      callbackWithTimer.onSuccess(
+                        ByteBuffer.wrap(
+                          
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+                    } else {
+                      callbackWithTimer.onSuccess(
+                        ByteBuffer.wrap(
+                          
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
+                    }
                   } else {
-                    callbackWithTimer.onSuccess(
-                      ByteBuffer.wrap(
-                        
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
+                    callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
                   }
-                } else {
+                case None =>
                   callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
-                }
-              case None =>
-                callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+              }
             }
           }
         case Failure(e) => callbackWithTimer.onFailure(e)
@@ -414,6 +432,7 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     val batchOffsets = pushMergedData.batchOffsets
     val body = pushMergedData.body.asInstanceOf[NettyManagedBuffer].getBuf
     val isPrimary = mode == PartitionLocation.Mode.PRIMARY
+    val (mapId, attemptId) = getMapAttempt(body)
 
     val key = s"${pushMergedData.requestId}"
     val callbackWithTimer =
@@ -430,6 +449,7 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
           key,
           callback)
       }
+    val pushMergedDataCallback = new PushMergedDataCallback(callbackWithTimer)
 
     // For test
     if (isPrimary && testPushPrimaryDataTimeout &&
@@ -458,7 +478,6 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     while (index < partitionIdToLocations.length) {
       val (id, loc) = partitionIdToLocations(index)
       if (loc == null) {
-        val (mapId, attemptId) = getMapAttempt(body)
         // MapperAttempts for a shuffle exists after any CommitFiles request 
succeeds.
         // A shuffle can trigger multiple CommitFiles requests, for reasons 
like: HARD_SPLIT happens, StageEnd.
         // If MapperAttempts but the value is -1 for the mapId(-1 means the 
map has not yet finished),
@@ -468,14 +487,13 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
             logDebug(s"Receive push merged data from speculative " +
               s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId), " +
               s"but this mapper has already been ended.")
-            callbackWithTimer.onSuccess(
-              ByteBuffer.wrap(Array[Byte](StatusCode.MAP_ENDED.getValue)))
+            pushMergedDataCallback.onSuccess(StatusCode.MAP_ENDED)
+            return
           } else {
             logDebug(s"[Case1] Receive push merged data for committed hard 
split partition of " +
               s"(shuffle $shuffleKey, map $mapId attempt $attemptId)")
             workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
-            callbackWithTimer.onSuccess(
-              ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+            pushMergedDataCallback.addSplitPartition(index, 
StatusCode.HARD_SPLIT)
           }
         } else {
           if (storageManager.shuffleKeySet().contains(shuffleKey)) {
@@ -485,16 +503,15 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
             logDebug(s"[Case2] Receive push merged data for committed hard 
split partition of " +
               s"(shuffle $shuffleKey, map $mapId attempt $attemptId)")
             workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
-            callbackWithTimer.onSuccess(
-              ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+            pushMergedDataCallback.addSplitPartition(index, 
StatusCode.HARD_SPLIT)
           } else {
             logWarning(s"While handling PushMergedData, Partition location 
wasn't found for " +
               s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId, 
uniqueId $id).")
-            callbackWithTimer.onFailure(
+            pushMergedDataCallback.onFailure(
               new 
CelebornIOException(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND))
+            return
           }
         }
-        return
       }
       index += 1
     }
@@ -502,7 +519,9 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     // During worker shutdown, worker will return HARD_SPLIT for all existed 
partition.
     // This should before return exception to make current push data can 
revive and retry.
     if (shutdown.get()) {
-      
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+      partitionIdToLocations.indices.foreach(index =>
+        pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT))
+      pushMergedDataCallback.onSuccess(StatusCode.HARD_SPLIT)
       return
     }
 
@@ -519,31 +538,48 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
         s"While handling PushMergedData, throw $cause, fileWriter 
$fileWriterWithException has exception.",
         fileWriterWithException.getException)
       workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
-      callbackWithTimer.onFailure(new CelebornIOException(cause))
+      pushMergedDataCallback.onFailure(new CelebornIOException(cause))
       return
     }
 
-    if (fileWriters.exists(checkDiskFull(_) == true)) {
-      val (mapId, attemptId) = getMapAttempt(body)
-      logWarning(
-        s"return hard split for disk full with shuffle $shuffleKey map $mapId 
attempt $attemptId")
-      
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
-      return
+    var fileWriterIndex = 0
+    val totalFileWriters = fileWriters.length
+    while (fileWriterIndex < totalFileWriters) {
+      val fileWriter = fileWriters(fileWriterIndex)
+      if (fileWriter == null) {
+        if (!pushMergedDataCallback.isHardSplitPartition(fileWriterIndex)) {
+          pushMergedDataCallback.onFailure(
+            new CelebornIOException(s"Partition $fileWriterIndex's fileWriter 
not found," +
+              s" but it hasn't been identified in the previous validation 
step."))
+          return
+        }
+      } else {
+        if (fileWriter.isClosed) {
+          val fileInfo = fileWriter.getCurrentFileInfo
+          logWarning(
+            s"[handlePushMergedData] FileWriter is already closed! File path 
${fileInfo.getFilePath} " +
+              s"length ${fileInfo.getFileLength}")
+          pushMergedDataCallback.addSplitPartition(fileWriterIndex, 
StatusCode.HARD_SPLIT)
+        } else {
+          val splitStatus = checkDiskFullAndSplit(fileWriter, isPrimary)
+          if (splitStatus == StatusCode.HARD_SPLIT) {
+            logWarning(
+              s"return hard split for disk full with shuffle $shuffleKey map 
$mapId attempt $attemptId")
+            workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+            pushMergedDataCallback.addSplitPartition(fileWriterIndex, 
StatusCode.HARD_SPLIT)
+          } else if (splitStatus == StatusCode.SOFT_SPLIT) {
+            pushMergedDataCallback.addSplitPartition(fileWriterIndex, 
StatusCode.SOFT_SPLIT)
+          }
+        }
+        if (!pushMergedDataCallback.isHardSplitPartition(fileWriterIndex)) {
+          fileWriter.incrementPendingWrites()
+        }
+      }
+      fileWriterIndex += 1
     }
 
-    fileWriters.foreach(_.incrementPendingWrites())
-
-    val closedFileWriter = fileWriters.find(_.isClosed)
-    if (closedFileWriter.isDefined) {
-      val fileInfo = closedFileWriter.get.getCurrentFileInfo
-      logWarning(
-        s"[handlePushMergedData] FileWriter is already closed! File path 
${fileInfo.getFilePath} " +
-          s"length ${fileInfo.getFileLength}")
-      callbackWithTimer.onFailure(new CelebornIOException("File already 
closed!"))
-      fileWriters.foreach(_.decrementPendingWrites())
-      return
-    }
-    val writePromise = Promise[Unit]()
+    val hardSplitIndexes = pushMergedDataCallback.getHardSplitIndexes
+    val writePromise = Promise[Array[StatusCode]]()
     // for primary, send data to replica
     if (doReplicate) {
       pushMergedData.body().retain()
@@ -562,7 +598,7 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
             
workerSource.incCounter(WorkerSource.REPLICATE_DATA_CREATE_CONNECTION_FAIL_COUNT)
             logError(
               s"PushMergedData replication failed caused by unavailable peer 
for partitionLocation: $location")
-            callbackWithTimer.onFailure(
+            pushMergedDataCallback.onFailure(
               new 
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA))
             return
           }
@@ -570,30 +606,71 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
           // Handle the response from replica
           val wrappedCallback = new RpcResponseCallback() {
             override def onSuccess(response: ByteBuffer): Unit = {
+              val replicaReason = response.get()
+              if (replicaReason == StatusCode.HARD_SPLIT.getValue) {
+                if (response.remaining() > 0) {
+                  try {
+                    val pushMergedDataResponse: 
PbPushMergedDataSplitPartitionInfo =
+                      TransportMessage.fromByteBuffer(
+                        
response).getParsedPayload[PbPushMergedDataSplitPartitionInfo]()
+                    pushMergedDataCallback.unionReplicaSplitPartitions(
+                      pushMergedDataResponse.getSplitPartitionIndexesList,
+                      pushMergedDataResponse.getStatusCodesList)
+                  } catch {
+                    case e: CelebornIOException =>
+                      pushMergedDataCallback.onFailure(e)
+                      return
+                    case e: IllegalArgumentException =>
+                      pushMergedDataCallback.onFailure(new 
CelebornIOException(e))
+                      return
+                  }
+                } else {
+                  // During the rolling upgrade of the worker cluster, it is 
possible for the primary worker
+                  // to be upgraded to a new version that includes the changes 
from [CELEBORN-1721], while
+                  // the replica worker is still running on an older version 
that does not have these changes.
+                  // In this scenario, the replica may return a response with 
a status of HARD_SPLIT, but
+                  // will not provide a PbPushMergedDataSplitPartitionInfo.
+                  logWarning(
+                    s"The response status from the replica (shuffle 
$shuffleKey map $mapId attempt $attemptId) is HARD_SPLIT, but no 
PbPushMergedDataSplitPartitionInfo is present.")
+                  partitionIdToLocations.indices.foreach(index =>
+                    pushMergedDataCallback.addSplitPartition(index, 
StatusCode.HARD_SPLIT))
+                  pushMergedDataCallback.onSuccess(StatusCode.HARD_SPLIT)
+                  return
+                }
+              }
               Try(Await.result(writePromise.future, Duration.Inf)) match {
-                case Success(_) =>
+                case Success(result) =>
+                  var index = 0
+                  while (index < result.length) {
+                    if (result(index) == StatusCode.HARD_SPLIT) {
+                      pushMergedDataCallback.addSplitPartition(index, 
result(index))
+                    }
+                    index += 1
+                  }
                   // Only primary data enable replication will push data to 
replica
-                  if (response.remaining() > 0) {
-                    val resp = ByteBuffer.allocate(response.remaining())
-                    resp.put(response)
-                    resp.flip()
-                    callbackWithTimer.onSuccess(resp)
-                  } else {
-                    Option(CongestionController.instance()) match {
-                      case Some(congestionController) if fileWriters.nonEmpty 
=>
-                        if (congestionController.isUserCongested(
-                            fileWriters.head.getUserCongestionControlContext)) 
{
-                          // Check whether primary congest the data though the 
replicas doesn't congest
-                          // it(the response is empty)
-                          callbackWithTimer.onSuccess(
-                            ByteBuffer.wrap(
-                              
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+                  Option(CongestionController.instance()) match {
+                    case Some(congestionController) if fileWriters.nonEmpty =>
+                      if (congestionController.isUserCongested(
+                          fileWriters.head.getUserCongestionControlContext)) {
+                        // Check whether primary congest the data though the 
replicas doesn't congest
+                        // it(the response is empty)
+                        pushMergedDataCallback.onSuccess(
+                          StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED)
+                      } else {
+                        if (replicaReason == 
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue) {
+                          pushMergedDataCallback.onSuccess(
+                            StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
                         } else {
-                          
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+                          pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
                         }
-                      case None =>
-                        
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
-                    }
+                      }
+                    case None =>
+                      if (replicaReason == 
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue) {
+                        pushMergedDataCallback.onSuccess(
+                          StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
+                      } else {
+                        pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
+                      }
                   }
                 case Failure(e) => callbackWithTimer.onFailure(e)
               }
@@ -606,17 +683,17 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
               // 3. Throw IOException by channel, convert to 
PUSH_DATA_CONNECTION_EXCEPTION_REPLICA
               if 
(e.getMessage.startsWith(StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA.name())) {
                 
workerSource.incCounter(WorkerSource.REPLICATE_DATA_WRITE_FAIL_COUNT)
-                callbackWithTimer.onFailure(e)
+                pushMergedDataCallback.onFailure(e)
               } else if 
(e.getMessage.startsWith(StatusCode.PUSH_DATA_TIMEOUT_REPLICA.name())) {
                 
workerSource.incCounter(WorkerSource.REPLICATE_DATA_TIMEOUT_COUNT)
-                callbackWithTimer.onFailure(e)
+                pushMergedDataCallback.onFailure(e)
               } else if (ExceptionUtils.connectFail(e.getMessage)) {
                 
workerSource.incCounter(WorkerSource.REPLICATE_DATA_CONNECTION_EXCEPTION_COUNT)
-                callbackWithTimer.onFailure(
+                pushMergedDataCallback.onFailure(
                   new 
CelebornIOException(StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA))
               } else {
                 
workerSource.incCounter(WorkerSource.REPLICATE_DATA_FAIL_NON_CRITICAL_CAUSE_COUNT)
-                callbackWithTimer.onFailure(
+                pushMergedDataCallback.onFailure(
                   new 
CelebornIOException(StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA))
               }
             }
@@ -642,39 +719,56 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
               logError(
                 s"PushMergedData replication failed during connecting peer for 
partitionLocation: $location",
                 e)
-              callbackWithTimer.onFailure(
+              pushMergedDataCallback.onFailure(
                 new 
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA))
           }
         }
       })
-      writeLocalData(fileWriters, body, shuffleKey, isPrimary, 
Some(batchOffsets), writePromise)
+      writeLocalData(
+        fileWriters,
+        body,
+        shuffleKey,
+        isPrimary,
+        Some(batchOffsets),
+        writePromise,
+        hardSplitIndexes)
     } else {
       // The codes here could be executed if
       // 1. the client doesn't enable push data to the replica, the primary 
worker could hit here
       // 2. the client enables push data to the replica, and the replica 
worker could hit here
-      writeLocalData(fileWriters, body, shuffleKey, isPrimary, 
Some(batchOffsets), writePromise)
+      writeLocalData(
+        fileWriters,
+        body,
+        shuffleKey,
+        isPrimary,
+        Some(batchOffsets),
+        writePromise,
+        hardSplitIndexes)
       Try(Await.result(writePromise.future, Duration.Inf)) match {
-        case Success(_) =>
+        case Success(result) =>
+          var index = 0
+          while (index < result.length) {
+            if (result(index) == StatusCode.HARD_SPLIT) {
+              pushMergedDataCallback.addSplitPartition(index, result(index))
+            }
+            index += 1
+          }
           Option(CongestionController.instance()) match {
             case Some(congestionController) if fileWriters.nonEmpty =>
               if (congestionController.isUserCongested(
                   fileWriters.head.getUserCongestionControlContext)) {
                 if (isPrimary) {
-                  callbackWithTimer.onSuccess(
-                    ByteBuffer.wrap(
-                      
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+                  
pushMergedDataCallback.onSuccess(StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED)
                 } else {
-                  callbackWithTimer.onSuccess(
-                    ByteBuffer.wrap(
-                      
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
+                  
pushMergedDataCallback.onSuccess(StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
                 }
               } else {
-                callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+                pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
               }
             case None =>
-              callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+              pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
           }
-        case Failure(e) => callbackWithTimer.onFailure(e)
+        case Failure(e) => pushMergedDataCallback.onFailure(e)
       }
     }
   }
@@ -691,11 +785,15 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     var exceptionFileWriterIndex: Option[Int] = None
     while (i < partitionIdToLocations.length) {
       val (_, workingPartition) = partitionIdToLocations(i)
-      val fileWriter = 
workingPartition.asInstanceOf[WorkingPartition].getFileWriter
-      if (fileWriter.getException != null) {
-        exceptionFileWriterIndex = Some(i)
+      if (workingPartition != null) {
+        val fileWriter = 
workingPartition.asInstanceOf[WorkingPartition].getFileWriter
+        if (fileWriter.getException != null) {
+          exceptionFileWriterIndex = Some(i)
+        }
+        fileWriters(i) = fileWriter
+      } else {
+        fileWriters(i) = null
       }
-      fileWriters(i) = fileWriter
       i += 1
     }
     (fileWriters, exceptionFileWriterIndex)
@@ -747,6 +845,76 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     }
   }
 
+  class PushMergedDataCallback(callback: RpcResponseCallback) {
+    private val splitPartitionStatuses = new mutable.HashMap[Int, Byte]()
+
+    def addSplitPartition(index: Int, statusCode: StatusCode): Unit = {
+      splitPartitionStatuses.put(index, statusCode.getValue)
+    }
+
+    def isHardSplitPartition(index: Int): Boolean = {
+      splitPartitionStatuses.getOrElse(index, -1) == 
StatusCode.HARD_SPLIT.getValue
+    }
+
+    def unionReplicaSplitPartitions(
+        replicaPartitionIndexes: util.List[Integer],
+        replicaStatusCodes: util.List[Integer]): Unit = {
+      if (replicaPartitionIndexes.size() != replicaStatusCodes.size()) {
+        throw new IllegalArgumentException(
+          "replicaPartitionIndexes and replicaStatusCodes must have the same 
size")
+      }
+      for (i <- 0 until replicaPartitionIndexes.size()) {
+        val index = replicaPartitionIndexes.get(i)
+        // The priority of HARD_SPLIT is higher than that of SOFT_SPLIT.
+        if (!isHardSplitPartition(index)) {
+          splitPartitionStatuses.put(index, 
replicaStatusCodes.get(i).byteValue())
+        }
+      }
+    }
+
+    /**
+     * Returns the ordered indexes of partitions that are not writable.
+     * A partition is considered not writable if it is marked as HARD_SPLIT or 
failed.
+     */
+    def getHardSplitIndexes: Array[Int] = {
+      splitPartitionStatuses.collect {
+        case (partitionIndex, statusCode) if statusCode == 
StatusCode.HARD_SPLIT.getValue =>
+          partitionIndex
+      }.toSeq.sorted.toArray
+    }
+
+    def onSuccess(status: StatusCode): Unit = {
+      val splitPartitionIndexes = new util.ArrayList[Integer]()
+      val statusCodes = new util.ArrayList[Integer]()
+      splitPartitionStatuses.foreach {
+        case (partitionIndex, statusCode) =>
+          splitPartitionIndexes.add(partitionIndex)
+          statusCodes.add(statusCode)
+      }
+      if (splitPartitionStatuses.isEmpty || status == StatusCode.MAP_ENDED) {
+        callback.onSuccess(
+          ByteBuffer.wrap(Array[Byte](status.getValue)))
+      } else {
+        val pushMergedDataInfo = 
PbPushMergedDataSplitPartitionInfo.newBuilder()
+          .addAllSplitPartitionIndexes(splitPartitionIndexes)
+          .addAllStatusCodes(statusCodes)
+          .build()
+        val pushMergedDataInfoByteBuffer = 
Utils.toTransportMessage(pushMergedDataInfo)
+          .asInstanceOf[TransportMessage]
+          .toByteBuffer
+        val response = ByteBuffer.allocate(1 + 
pushMergedDataInfoByteBuffer.remaining())
+        response.put(StatusCode.HARD_SPLIT.getValue)
+        response.put(pushMergedDataInfoByteBuffer)
+        response.flip()
+        callback.onSuccess(response)
+      }
+    }
+
+    def onFailure(exception: Throwable): Unit = {
+      callback.onFailure(exception)
+    }
+  }
+
   private def handleCore(
       client: TransportClient,
       message: RequestMessage,
@@ -821,18 +989,28 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       fileWriter.decrementPendingWrites()
       return
     }
-    val writePromise = Promise[Unit]()
+    val writePromise = Promise[Array[StatusCode]]()
     writeLocalData(Seq(fileWriter), body, shuffleKey, isPrimary, None, 
writePromise)
     // for primary, send data to replica
     if (location.hasPeer && isPrimary) {
       // to do
       Try(Await.result(writePromise.future, Duration.Inf)) match {
-        case Success(_) => 
wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+        case Success(result) =>
+          if (result(0) != StatusCode.SUCCESS) {
+            wrappedCallback.onFailure(new CelebornIOException("Write data 
failed!"))
+          } else {
+            wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+          }
         case Failure(e) => wrappedCallback.onFailure(e)
       }
     } else {
       Try(Await.result(writePromise.future, Duration.Inf)) match {
-        case Success(_) => 
wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+        case Success(result) =>
+          if (result(0) != StatusCode.SUCCESS) {
+            wrappedCallback.onFailure(new CelebornIOException("Write data 
failed!"))
+          } else {
+            wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+          }
         case Failure(e) => wrappedCallback.onFailure(e)
       }
     }
@@ -1030,9 +1208,11 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     if (checkSplit && (messageType == Type.REGION_START || messageType ==
         Type.PUSH_DATA_HAND_SHAKE) && isPartitionSplitEnabled && 
checkDiskFullAndSplit(
         fileWriter,
-        isPrimary,
-        null,
-        callback)) return
+        isPrimary) == StatusCode.HARD_SPLIT) {
+      workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+      
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+      return
+    }
 
     try {
       messageType match {
@@ -1226,9 +1406,12 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
 
   private def checkDiskFullAndSplit(
       fileWriter: PartitionDataWriter,
-      isPrimary: Boolean,
-      softSplit: AtomicBoolean,
-      callback: RpcResponseCallback): Boolean = {
+      isPrimary: Boolean): StatusCode = {
+    if (fileWriter.needHardSplitForMemoryShuffleStorage()) {
+      logInfo(
+        s"Do hardSplit for memory shuffle file 
fileLength:${fileWriter.getMemoryFileInfo.getFileLength}")
+      return StatusCode.HARD_SPLIT
+    }
     val diskFull = checkDiskFull(fileWriter)
     logTrace(
       s"""
@@ -1239,24 +1422,14 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
          |fileLength:${fileWriter.getCurrentFileInfo.getFileLength}
          |fileName:${fileWriter.getCurrentFileInfo.getFilePath}
          |""".stripMargin)
-    if (fileWriter.needHardSplitForMemoryShuffleStorage()) {
-      workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
-      
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
-      logInfo(
-        s"Do hardSplit for memory shuffle file 
fileLength:${fileWriter.getMemoryFileInfo.getFileLength}")
-      return true
-    }
-
     val diskFileInfo = fileWriter.getDiskFileInfo
     if (diskFileInfo != null) {
       if (workerPartitionSplitEnabled && ((diskFull && 
diskFileInfo.getFileLength > partitionSplitMinimumSize) ||
           (isPrimary && diskFileInfo.getFileLength > 
fileWriter.getSplitThreshold))) {
-        if (softSplit != null && fileWriter.getSplitMode == 
PartitionSplitMode.SOFT &&
+        if (fileWriter.getSplitMode == PartitionSplitMode.SOFT &&
           (fileWriter.getDiskFileInfo.getFileLength < 
partitionSplitMaximumSize)) {
-          softSplit.set(true)
+          return StatusCode.SOFT_SPLIT
         } else {
-          workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
-          
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
           logInfo(
             s"""
                |CheckDiskFullAndSplit hardSplit
@@ -1266,11 +1439,11 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
                |fileLength:${diskFileInfo.getFileLength},
                |fileName:${diskFileInfo.getFilePath}
                |""".stripMargin)
-          return true
+          return StatusCode.HARD_SPLIT
         }
       }
     }
-    false
+    StatusCode.NO_SPLIT
   }
 
   private def getReplicateClient(host: String, port: Int, partitionId: Int): 
TransportClient = {
@@ -1287,10 +1460,18 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       shuffleKey: String,
       isPrimary: Boolean,
       batchOffsets: Option[Array[Int]],
-      writePromise: Promise[Unit]): Unit = {
-    def writeData(fileWriter: PartitionDataWriter, body: ByteBuf, shuffleKey: 
String): Unit = {
+      writePromise: Promise[Array[StatusCode]],
+      hardSplitIndexes: Array[Int] = Array.empty[Int]): Unit = {
+    val length = fileWriters.length
+    val result = new Array[StatusCode](length)
+    def writeData(
+        fileWriter: PartitionDataWriter,
+        body: ByteBuf,
+        shuffleKey: String,
+        index: Int): Unit = {
       try {
         fileWriter.write(body)
+        result(index) = StatusCode.SUCCESS
       } catch {
         case e: Exception =>
           if (e.isInstanceOf[AlreadyClosedException]) {
@@ -1302,47 +1483,63 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
             // TODO just info log for ended attempt
             logWarning(s"Append data failed for task(shuffle $shuffleKey, map 
$mapId, attempt" +
               s" $attemptId), caused by AlreadyClosedException, endedAttempt 
$endedAttempt, error message: ${e.getMessage}")
+            workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+            result(index) = StatusCode.HARD_SPLIT
           } else {
             logError("Exception encountered when write.", e)
+            workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
+            val cause =
+              if (isPrimary) {
+                StatusCode.PUSH_DATA_WRITE_FAIL_PRIMARY
+              } else {
+                StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA
+              }
+            writePromise.failure(new CelebornIOException(cause))
           }
-          val cause =
-            if (isPrimary) {
-              StatusCode.PUSH_DATA_WRITE_FAIL_PRIMARY
-            } else {
-              StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA
-            }
-          workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
-          writePromise.failure(new CelebornIOException(cause))
           fileWriter.decrementPendingWrites()
       }
     }
     batchOffsets match {
       case Some(batchOffsets) =>
         var index = 0
+        val hardSplitIterator = hardSplitIndexes.iterator
+        var currentHardSplitIndex = nextValueOrElse(hardSplitIterator, -1)
         var fileWriter: PartitionDataWriter = null
         while (index < fileWriters.length) {
-          fileWriter = fileWriters(index)
-          if (!writePromise.isCompleted) {
-            val offset = body.readerIndex() + batchOffsets(index)
-            val length =
-              if (index == fileWriters.length - 1) {
-                body.readableBytes() - batchOffsets(index)
-              } else {
-                batchOffsets(index + 1) - batchOffsets(index)
-              }
-            val batchBody = body.slice(offset, length)
-            writeData(fileWriter, batchBody, shuffleKey)
+          if (index == currentHardSplitIndex) {
+            currentHardSplitIndex = nextValueOrElse(hardSplitIterator, -1)
           } else {
-            fileWriter.decrementPendingWrites()
+            fileWriter = fileWriters(index)
+            if (!writePromise.isCompleted) {
+              val offset = body.readerIndex() + batchOffsets(index)
+              val length =
+                if (index == fileWriters.length - 1) {
+                  body.readableBytes() - batchOffsets(index)
+                } else {
+                  batchOffsets(index + 1) - batchOffsets(index)
+                }
+              val batchBody = body.slice(offset, length)
+              writeData(fileWriter, batchBody, shuffleKey, index)
+            } else {
+              fileWriter.decrementPendingWrites()
+            }
           }
           index += 1
         }
       case _ =>
-        writeData(fileWriters.head, body, shuffleKey)
+        writeData(fileWriters.head, body, shuffleKey, 0)
     }
     if (!writePromise.isCompleted) {
       workerSource.incCounter(WorkerSource.WRITE_DATA_SUCCESS_COUNT)
-      writePromise.success()
+      writePromise.success(result)
+    }
+  }
+
+  private def nextValueOrElse(iterator: Iterator[Int], defaultValue: Int): Int 
= {
+    if (iterator.hasNext) {
+      iterator.next()
+    } else {
+      defaultValue
     }
   }
 
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala
new file mode 100644
index 000000000..cdd0e758e
--- /dev/null
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.cluster
+
+import java.nio.charset.StandardCharsets
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.commons.lang3.RandomStringUtils
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.meta.WorkerInfo
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+class PushMergedDataSplitSuite extends AnyFunSuite
+  with Logging with MiniClusterFeature with BeforeAndAfterAll {
+
+  var masterEndpoint = ""
+  override def beforeAll(): Unit = {
+    val conf = Map("celeborn.worker.flusher.buffer.size" -> "0")
+
+    logInfo("test initialized , setup Celeborn mini cluster")
+    val (master, _) = setupMiniClusterWithRandomPorts(conf, conf, 2)
+    masterEndpoint = master.conf.get(CelebornConf.MASTER_ENDPOINTS.key)
+  }
+
+  override def afterAll(): Unit = {
+    logInfo("all test complete , stop Celeborn mini cluster")
+    super.shutdownMiniCluster()
+  }
+
+  test("push merged data and partial partition are split") {
+    val SHUFFLE_ID = 0
+    val MAP_ID = 0
+    val ATTEMPT_ID = 0
+    val MAP_NUM = 1
+    val PARTITION_NUM = 3
+
+    Array("SOFT", "HARD").foreach {
+      splitMode =>
+        val APP = s"app-${System.currentTimeMillis()}"
+        val clientConf = new CelebornConf()
+          .set(CelebornConf.MASTER_ENDPOINTS.key, masterEndpoint)
+          .set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "1")
+          .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "5K")
+          .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_MODE.key, splitMode)
+        val lifecycleManager = new LifecycleManager(APP, clientConf)
+        val shuffleClient = new ShuffleClientImpl(APP, clientConf, 
UserIdentifier("mock", "mock"))
+        shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+        // ping and reserveSlots
+        val DATA0 = 
RandomStringUtils.secure().next(10).getBytes(StandardCharsets.UTF_8)
+        shuffleClient.pushData(
+          SHUFFLE_ID,
+          MAP_ID,
+          ATTEMPT_ID,
+          0,
+          DATA0,
+          0,
+          DATA0.length,
+          MAP_NUM,
+          PARTITION_NUM)
+
+        // find the worker that has at least 2 partitions
+        val partitionLocationMap =
+          shuffleClient.getPartitionLocation(SHUFFLE_ID, MAP_NUM, 
PARTITION_NUM)
+        val worker2PartitionIds = mutable.Map.empty[WorkerInfo, 
ArrayBuffer[Int]]
+        for (partitionId <- 0 until PARTITION_NUM) {
+          val partitionLocation = partitionLocationMap.get(partitionId)
+          worker2PartitionIds
+            .getOrElseUpdate(partitionLocation.getWorker, ArrayBuffer.empty)
+            .append(partitionId)
+        }
+        val partitions = worker2PartitionIds.values.filter(_.size >= 2).head
+        assert(partitions.length >= 2)
+
+        // prepare merged data
+        val PARTITION0_DATA = 
RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
+        shuffleClient.mergeData(
+          SHUFFLE_ID,
+          MAP_ID,
+          ATTEMPT_ID,
+          partitions(0),
+          PARTITION0_DATA,
+          0,
+          PARTITION0_DATA.length,
+          MAP_NUM,
+          PARTITION_NUM)
+
+        val PARTITION1_DATA = 
RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
+        shuffleClient.mergeData(
+          SHUFFLE_ID,
+          MAP_ID,
+          ATTEMPT_ID,
+          partitions(1),
+          PARTITION1_DATA,
+          0,
+          PARTITION1_DATA.length,
+          MAP_NUM,
+          PARTITION_NUM)
+
+        // pushData until partition(0) is split
+        val GIANT_DATA =
+          RandomStringUtils.secure().next(1024 * 
100).getBytes(StandardCharsets.UTF_8)
+        shuffleClient.pushData(
+          SHUFFLE_ID,
+          MAP_ID,
+          ATTEMPT_ID,
+          partitions(0),
+          GIANT_DATA,
+          0,
+          GIANT_DATA.length,
+          MAP_NUM,
+          PARTITION_NUM)
+        for (_ <- 0 until 5) {
+          val TRIGGER_DATA = 
RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
+          shuffleClient.pushData(
+            SHUFFLE_ID,
+            MAP_ID,
+            ATTEMPT_ID,
+            partitions(0),
+            TRIGGER_DATA,
+            0,
+            TRIGGER_DATA.length,
+            MAP_NUM,
+            PARTITION_NUM)
+          Thread.sleep(5 * 1000) // wait for flush
+        }
+        assert(
+          partitionLocationMap.get(partitions(0)).getEpoch > 0
+        ) // means partition(0) will be split
+
+        // push merged data, we expect that partition(0) will be split, while 
partition(1) will not be split
+        shuffleClient.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID)
+        shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM)
+        assert(
+          partitionLocationMap.get(partitions(1)).getEpoch == 0
+        ) // means partition(1) will not be split
+    }
+  }
+}

Reply via email to