This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch test
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git

commit 540912ed512321f852fd75e14887c36fa2081a6a
Author: zky.zhoukeyong <[email protected]>
AuthorDate: Fri Jun 9 19:43:39 2023 +0800

    refine revive
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 158 +++++++++++++--------
 .../celeborn/client/ChangePartitionManager.scala   |   7 +-
 .../apache/celeborn/client/LifecycleManager.scala  |  79 +++++++----
 .../client/RequestLocationCallContext.scala        |  35 ++++-
 .../common/protocol/PartitionLocation.java         |   2 +
 common/src/main/proto/TransportMessages.proto      |  19 +++
 .../common/protocol/message/ControlMessages.scala  |  55 +++++++
 .../service/deploy/worker/PushDataHandler.scala    |   5 +-
 8 files changed, 266 insertions(+), 94 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index e32dd1d8d..f2f3de2b3 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -175,7 +175,7 @@ public class ShuffleClientImpl extends ShuffleClient {
       pushState.removeBatch(batchId);
     } else {
       PartitionLocation newLoc = 
reducePartitionMap.get(shuffleId).get(partitionId);
-      logger.info("Revive success, new location for reduce {} is {}.", 
partitionId, newLoc);
+      logger.debug("Revive success, new location for reduce {} is {}.", 
partitionId, newLoc);
       try {
         TransportClient client =
             dataClientFactory.createClient(newLoc.getHost(), 
newLoc.getPushPort(), partitionId);
@@ -209,27 +209,34 @@ public class ShuffleClientImpl extends ShuffleClient {
       StatusCode cause,
       Integer oldGroupedBatchId) {
     HashMap<String, DataBatches> newDataBatchesMap = new HashMap<>();
-    for (DataBatches.DataBatch batch : batches) {
-      int partitionId = batch.loc.getId();
-      if (!revive(
-          applicationId,
-          shuffleId,
-          mapId,
-          attemptId,
-          partitionId,
-          batch.loc.getEpoch(),
-          batch.loc,
-          cause)) {
-        pushState.exception.compareAndSet(
-            null,
-            new IOException("Revive Failed in retry push merged data for 
location: " + batch.loc));
-        return;
-      } else if (mapperEnded(shuffleId, mapId, attemptId)) {
-        logger.debug(
-            "Retrying push data, but the mapper(map {} attempt {}) has 
ended.", mapId, attemptId);
-      } else {
-        PartitionLocation newLoc = 
reducePartitionMap.get(shuffleId).get(partitionId);
-        logger.info("Revive success, new location for reduce {} is {}.", 
partitionId, newLoc);
+    int[] ids = new int[batches.size()];
+    int[] epoches = new int[batches.size()];
+    PartitionLocation[] locs = new PartitionLocation[batches.size()];
+    StatusCode[] causes = new StatusCode[batches.size()];
+
+    for (int i = 0; i < batches.size(); i++) {
+      DataBatches.DataBatch batch = batches.get(i);
+      ids[i] = batch.loc.getId();
+      epoches[i] = batch.loc.getEpoch();
+      locs[i] = batch.loc;
+      causes[i] = cause;
+    }
+
+    boolean reviveSuccess =
+        reviveBatch(applicationId, shuffleId, mapId, attemptId, ids, epoches, 
locs, causes);
+    if (!reviveSuccess) {
+      pushState.exception.compareAndSet(
+          null,
+          new IOException(
+              "Revive Failed in retry push merged data for location: " + 
locs[0].getHost()));
+      return;
+    } else if (mapperEnded(shuffleId, mapId, attemptId)) {
+      logger.debug(
+          "Retrying push data, but the mapper(map {} attempt {}) has ended.", 
mapId, attemptId);
+    } else {
+      for (int i = 0; i < batches.size(); i++) {
+        DataBatches.DataBatch batch = batches.get(i);
+        PartitionLocation newLoc = 
reducePartitionMap.get(shuffleId).get(batch.loc.getId());
         DataBatches newDataBatches =
             newDataBatchesMap.computeIfAbsent(genAddressPair(newLoc), (s) -> 
new DataBatches());
         newDataBatches.addDataBatch(newLoc, batch.batchId, batch.body);
@@ -395,20 +402,22 @@ public class ShuffleClientImpl extends ShuffleClient {
     PartitionLocation currentLocation = map.get(partitionId);
     if (currentLocation != null && currentLocation.getEpoch() > epoch) {
       return true;
+    } else {
+      return false;
     }
-
-    long sleepTimeMs = RND.nextInt(50);
-    if (sleepTimeMs > 30) {
-      try {
-        TimeUnit.MILLISECONDS.sleep(sleepTimeMs);
-      } catch (InterruptedException e) {
-        logger.warn("Wait revived location interrupted", e);
-        Thread.currentThread().interrupt();
-      }
-    }
-
-    currentLocation = map.get(partitionId);
-    return currentLocation != null && currentLocation.getEpoch() > epoch;
+//
+//    long sleepTimeMs = RND.nextInt(50);
+//    if (sleepTimeMs > 30) {
+//      try {
+//        TimeUnit.MILLISECONDS.sleep(sleepTimeMs);
+//      } catch (InterruptedException e) {
+//        logger.warn("Wait revived location interrupted", e);
+//        Thread.currentThread().interrupt();
+//      }
+//    }
+
+//    currentLocation = map.get(partitionId);
+//    return currentLocation != null && currentLocation.getEpoch() > epoch;
   }
 
   private boolean revive(
@@ -431,6 +440,28 @@ public class ShuffleClientImpl extends ShuffleClient {
           epoch);
       return true;
     }
+
+    return reviveBatch(
+        applicationId,
+        shuffleId,
+        mapId,
+        attemptId,
+        new int[] {partitionId},
+        new int[] {epoch},
+        new PartitionLocation[] {oldLocation},
+        new StatusCode[] {cause});
+  }
+
+  private boolean reviveBatch(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int[] partitionIds,
+      int[] epoches,
+      PartitionLocation[] oldLocations,
+      StatusCode[] causes) {
+    ConcurrentHashMap<Integer, PartitionLocation> map = 
reducePartitionMap.get(shuffleId);
     String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
     if (mapperEnded(shuffleId, mapId, attemptId)) {
       logger.debug(
@@ -442,36 +473,45 @@ public class ShuffleClientImpl extends ShuffleClient {
     }
 
     try {
-      PbChangeLocationResponse response =
+      PbChangeLocationsResponse response =
           driverRssMetaService.askSync(
-              Revive$.MODULE$.apply(
+              ReviveBatch$.MODULE$.apply(
                   applicationId,
                   shuffleId,
                   mapId,
                   attemptId,
-                  partitionId,
-                  epoch,
-                  oldLocation,
-                  cause),
+                  partitionIds,
+                  epoches,
+                  oldLocations,
+                  causes),
               conf.requestPartitionLocationRpcAskTimeout(),
-              ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
-      // per partitionKey only serve single PartitionLocation in Client Cache.
-      StatusCode respStatus = Utils.toStatusCode(response.getStatus());
-      if (StatusCode.SUCCESS.equals(respStatus)) {
-        map.put(partitionId, 
PbSerDeUtils.fromPbPartitionLocation(response.getLocation()));
-        return true;
-      } else if (StatusCode.MAP_ENDED.equals(respStatus)) {
-        mapperEndMap.computeIfAbsent(shuffleId, (id) -> 
ConcurrentHashMap.newKeySet()).add(mapKey);
-        return true;
-      } else {
-        return false;
+              ClassTag$.MODULE$.apply(PbChangeLocationsResponse.class));
+      boolean allSuccess = true;
+      for (int i = 0; i < response.getIdsCount(); i++) {
+        if 
(StatusCode.SUCCESS.equals(Utils.toStatusCode(response.getStatus(i)))) {
+          map.put(
+              response.getIds(i), 
PbSerDeUtils.fromPbPartitionLocation(response.getLocation(i)));
+        } else {
+          logger.info("StatusCode not SUCCESS! {}", response.getStatus(i));
+          allSuccess = false;
+        }
       }
+      if (!allSuccess) {
+        StatusCode firstStatus = Utils.toStatusCode(response.getStatus(0));
+        if (StatusCode.MAP_ENDED.equals(firstStatus)) {
+          mapperEndMap
+              .computeIfAbsent(shuffleId, (id) -> 
ConcurrentHashMap.newKeySet())
+              .add(mapKey);
+          return true;
+        }
+      }
+      return allSuccess;
     } catch (Exception e) {
       logger.error(
           "Exception raised while reviving for shuffle {} reduce {} epoch {}.",
           shuffleId,
-          partitionId,
-          epoch,
+          partitionIds,
+          epoches,
           e);
       return false;
     }
@@ -645,7 +685,8 @@ public class ShuffleClientImpl extends ShuffleClient {
                       nextBatchId);
                   splitPartition(shuffleId, partitionId, applicationId, loc);
                   callback.onSuccess(response);
-                } else if (reason == StatusCode.HARD_SPLIT.getValue()) {
+                } else if (reason == StatusCode.HARD_SPLIT.getValue() ||
+                reason == StatusCode.WORKER_SHUTDOWN.getValue()) {
                   logger.debug(
                       "Push data split for map {} attempt {} batch {}.",
                       mapId,
@@ -663,7 +704,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                               loc,
                               this,
                               pushState,
-                              StatusCode.HARD_SPLIT));
+                              Utils.toStatusCode(reason)));
                 } else {
                   response.rewind();
                   callback.onSuccess(response);
@@ -944,7 +985,8 @@ public class ShuffleClientImpl extends ShuffleClient {
           public void onSuccess(ByteBuffer response) {
             if (response.remaining() > 0) {
               byte reason = response.get();
-              if (reason == StatusCode.HARD_SPLIT.getValue()) {
+              if (reason == StatusCode.HARD_SPLIT.getValue() ||
+                reason == StatusCode.WORKER_SHUTDOWN.getValue()) {
                 logger.info(
                     "Push merged data return hard split for map "
                         + mapId
@@ -962,7 +1004,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                             mapId,
                             attemptId,
                             batches,
-                            StatusCode.HARD_SPLIT,
+                            Utils.toStatusCode(reason),
                             groupedBatchId));
               } else {
                 // Should not happen in current architecture.
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index ea1de6af4..62832cc85 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -161,7 +161,7 @@ class ChangePartitionManager(
         // If new slot for the partition has been allocated, reply and return.
         // Else register and allocate for it.
         getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { 
latestLoc =>
-          context.reply(StatusCode.SUCCESS, Some(latestLoc))
+          context.reply(partitionId, StatusCode.SUCCESS, Some(latestLoc))
           logDebug(s"New partition found, old partition $partitionId-$oldEpoch 
return it." +
             s" shuffleId: $shuffleId $latestLoc")
           return
@@ -204,6 +204,7 @@ class ChangePartitionManager(
     // Blacklist all failed workers
     if (changePartitions.exists(_.causes.isDefined)) {
       changePartitions.filter(_.causes.isDefined).foreach { changePartition =>
+        logInfo("cause is " + changePartition.causes.get)
         lifecycleManager.blacklistPartition(
           shuffleId,
           changePartition.oldPartition,
@@ -224,6 +225,7 @@ class ChangePartitionManager(
         }
       }.foreach { case (newLocation, requests) =>
         requests.map(_.asScala.toList.foreach(_.context.reply(
+          newLocation.getId,
           StatusCode.SUCCESS,
           Option(newLocation))))
       }
@@ -239,7 +241,8 @@ class ChangePartitionManager(
           Option(requestsMap.remove(changePartition.partitionId))
         }
       }.foreach { requests =>
-        requests.map(_.asScala.toList.foreach(_.context.reply(status, None)))
+        requests.map(_.asScala.toList.foreach(req =>
+          req.context.reply(req.partitionId, status, None)))
       }
     }
 
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 1f766fed6..5b1425316 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -22,15 +22,12 @@ import java.util
 import java.util.{function, List => JList}
 import java.util.concurrent.{Callable, ConcurrentHashMap, 
ScheduledExecutorService, ScheduledFuture, TimeUnit}
 import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder}
-
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
-
 import com.google.common.annotations.VisibleForTesting
 import com.google.common.cache.{Cache, CacheBuilder}
 import org.roaringbitmap.RoaringBitmap
-
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.haclient.RssHARetryClient
 import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}
@@ -439,10 +436,33 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
         shuffleId,
         mapId,
         attemptId,
-        partitionId,
-        epoch,
-        oldPartition,
-        cause)
+        Array(partitionId),
+        Array(epoch),
+        Array(oldPartition),
+        Array(cause))
+
+    case pb: PbReviveBatch =>
+      val applicationId = pb.getApplicationId
+      val shuffleId = pb.getShuffleId
+      val mapId = pb.getMapId
+      val attemptId = pb.getAttemptId
+      val ids = pb.getPartitionIdList.asScala.map(_.toInt).toArray
+      val epoches = pb.getEpochList.asScala.map(_.toInt).toArray
+      val oldPartitions = 
pb.getOldPartitionList.asScala.map(PbSerDeUtils.fromPbPartitionLocation(_)).toArray
+      val causes = pb.getStatusList.asScala.map(Utils.toStatusCode(_)).toArray
+      logTrace(s"Received ReviveBatch request, " +
+        s"$applicationId, $shuffleId, $mapId, $attemptId, 
,${ids.mkString(",")}," +
+        s" ${epoches.mkString(",")}, ${oldPartitions.mkString(",")}, 
${causes.mkString(",")}")
+      handleRevive(
+        context,
+        applicationId,
+        shuffleId,
+        mapId,
+        attemptId,
+        ids,
+        epoches,
+        oldPartitions,
+        causes)
 
     case pb: PbPartitionSplit =>
       val applicationId = pb.getApplicationId
@@ -684,12 +704,12 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       cause: StatusCode): Unit = {
     // only blacklist if cause is PushDataFailMain
     val failedWorker = new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]()
-    if (cause == StatusCode.PUSH_DATA_FAIL_MASTER && oldPartition != null) {
+    if ((cause == StatusCode.PUSH_DATA_FAIL_MASTER || cause == 
StatusCode.WORKER_SHUTDOWN) && oldPartition != null) {
       val tmpWorker = oldPartition.getWorker
       val worker = workerSnapshots(shuffleId).keySet().asScala
         .find(_.equals(tmpWorker))
       if (worker.isDefined) {
-        failedWorker.put(worker.get, (StatusCode.PUSH_DATA_FAIL_MASTER, 
System.currentTimeMillis()))
+        failedWorker.put(worker.get, (cause, System.currentTimeMillis()))
       }
     }
     if (!failedWorker.isEmpty) {
@@ -703,14 +723,15 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       shuffleId: Int,
       mapId: Int,
       attemptId: Int,
-      partitionId: Int,
-      oldEpoch: Int,
-      oldPartition: PartitionLocation,
-      cause: StatusCode): Unit = {
+      partitionIds: Array[Int],
+      oldEpochs: Array[Int],
+      oldPartitions: Array[PartitionLocation],
+      causes: Array[StatusCode]): Unit = {
+    val contextWrapper = ChangeLocationsCallContext(context, 
partitionIds.length)
     // If shuffle not registered, reply ShuffleNotRegistered and return
     if (!registeredShuffle.contains(shuffleId)) {
       logError(s"[handleRevive] shuffle $shuffleId not registered!")
-      context.reply(ChangeLocationResponse(StatusCode.SHUFFLE_NOT_REGISTERED, 
None))
+      contextWrapper.reply(-1, StatusCode.SHUFFLE_NOT_REGISTERED, None)
       return
     }
 
@@ -719,21 +740,23 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       && shuffleMapperAttempts.get(shuffleId)(mapId) != -1) {
       logWarning(s"[handleRevive] Mapper ended, mapId $mapId, current 
attemptId $attemptId, " +
         s"ended attemptId ${shuffleMapperAttempts.get(shuffleId)(mapId)}, 
shuffleId $shuffleId.")
-      context.reply(ChangeLocationResponse(StatusCode.MAP_ENDED, None))
+      contextWrapper.reply(-1, StatusCode.MAP_ENDED, None)
       return
     }
 
-    logWarning(s"Do Revive for shuffle ${Utils.makeShuffleKey(applicationId, 
shuffleId)}, " +
-      s"oldPartition: $oldPartition, cause: $cause")
-
-    changePartitionManager.handleRequestPartitionLocation(
-      ChangeLocationCallContext(context),
-      applicationId,
-      shuffleId,
-      partitionId,
-      oldEpoch,
-      oldPartition,
-      Some(cause))
+    logDebug(s"Do Revive for shuffle ${Utils.makeShuffleKey(applicationId, 
shuffleId)}, " +
+      s"oldPartition: ${oldPartitions.mkString(",")}, cause: 
${causes.mkString(",")}")
+
+    0 until partitionIds.length foreach (idx => {
+      changePartitionManager.handleRequestPartitionLocation(
+        contextWrapper,
+        applicationId,
+        shuffleId,
+        partitionIds(idx),
+        oldEpochs(idx),
+        oldPartitions(idx),
+        Some(causes(idx)))
+    })
   }
 
   private def handleMapperEnd(
@@ -1252,7 +1275,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       }
     if (!destroyResource.isEmpty) {
       destroySlotsWithRetry(applicationId, shuffleId, destroyResource)
-      logInfo(s"Destroyed peer partitions for reserve buffer failed workers " +
+      logDebug(s"Destroyed peer partitions for reserve buffer failed workers " 
+
         s"${Utils.makeShuffleKey(applicationId, shuffleId)}, $destroyResource")
 
       val workerIds = new util.ArrayList[String]()
@@ -1264,7 +1287,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       }
       val msg = ReleaseSlots(applicationId, shuffleId, workerIds, 
workerSlotsPerDisk)
       requestReleaseSlots(rssHARetryClient, msg)
-      logInfo(s"Released slots for reserve buffer failed workers " +
+      logDebug(s"Released slots for reserve buffer failed workers " +
         s"${workerIds.asScala.mkString(",")}" + 
s"${slots.asScala.mkString(",")}" +
         s"${Utils.makeShuffleKey(applicationId, shuffleId)}, ")
     }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
 
b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
index 91cbf6df6..0aa43f7ca 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
@@ -17,23 +17,50 @@
 
 package org.apache.celeborn.client
 
+import java.util.concurrent.ConcurrentHashMap
+
 import org.apache.celeborn.common.protocol.PartitionLocation
-import 
org.apache.celeborn.common.protocol.message.ControlMessages.{ChangeLocationResponse,
 RegisterShuffleResponse}
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.{ChangeLocationResponse,
 ChangeLocationsResponse, RegisterShuffleResponse}
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc.RpcCallContext
 
 trait RequestLocationCallContext {
-  def reply(status: StatusCode, partitionLocationOpt: 
Option[PartitionLocation]): Unit
+  def reply(id: Int, status: StatusCode, partitionLocationOpt: 
Option[PartitionLocation]): Unit
 }
 
 case class ChangeLocationCallContext(context: RpcCallContext) extends 
RequestLocationCallContext {
-  override def reply(status: StatusCode, partitionLocationOpt: 
Option[PartitionLocation]): Unit = {
+  override def reply(
+      id: Int,
+      status: StatusCode,
+      partitionLocationOpt: Option[PartitionLocation]): Unit = {
     context.reply(ChangeLocationResponse(status, partitionLocationOpt))
   }
 }
 
+case class ChangeLocationsCallContext(context: RpcCallContext, count: Int)
+  extends RequestLocationCallContext {
+  val locMap = new ConcurrentHashMap[Int, (StatusCode, 
PartitionLocation)](count)
+  override def reply(
+      id: Int,
+      status: StatusCode,
+      partitionLocationOpt: Option[PartitionLocation]): Unit = {
+    locMap.putIfAbsent(id, (status, partitionLocationOpt.getOrElse(new 
PartitionLocation())))
+    if (locMap.size() == count) {
+      locMap.synchronized {
+        if (locMap.size() == count || id == -1) {
+          context.reply(ChangeLocationsResponse(locMap))
+          locMap.clear()
+        }
+      }
+    }
+  }
+}
+
 case class ApplyNewLocationCallContext(context: RpcCallContext) extends 
RequestLocationCallContext {
-  override def reply(status: StatusCode, partitionLocationOpt: 
Option[PartitionLocation]): Unit = {
+  override def reply(
+      id: Int,
+      status: StatusCode,
+      partitionLocationOpt: Option[PartitionLocation]): Unit = {
     partitionLocationOpt match {
       case Some(partitionLocation) =>
         context.reply(RegisterShuffleResponse(status, 
Array(partitionLocation)))
diff --git 
a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
 
b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
index dccf87833..58be2afbc 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java
@@ -61,6 +61,8 @@ public class PartitionLocation implements Serializable {
   private StorageInfo storageInfo;
   private RoaringBitmap mapIdBitMap;
 
+  public PartitionLocation() {}
+
   public PartitionLocation(PartitionLocation loc) {
     this.id = loc.id;
     this.epoch = loc.epoch;
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 0c1377dec..9e80b2e40 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -66,6 +66,8 @@ enum MessageType {
   STAGE_END = 45;
   STAGE_END_RESPONSE = 46;
   PARTITION_SPLIT = 47;
+  REVIVE_BATCH = 48;
+  CHANGE_LOCATIONS_RESPONSE = 49;
 }
 
 message PbStorageInfo {
@@ -214,11 +216,28 @@ message PbRevive {
   int32 status = 8;
 }
 
+message PbReviveBatch {
+  string applicationId = 1;
+  int32 shuffleId = 2;
+  int32 mapId = 3;
+  int32 attemptId = 4;
+  repeated int32 partitionId = 5;
+  repeated int32 epoch = 6;
+  repeated PbPartitionLocation oldPartition = 7;
+  repeated int32 status = 8;
+}
+
 message PbChangeLocationResponse {
   int32 status = 1;
   PbPartitionLocation location = 2;
 }
 
+message PbChangeLocationsResponse {
+  repeated int32 ids = 1;
+  repeated int32 status = 2;
+  repeated PbPartitionLocation location = 3;
+}
+
 message PbPartitionSplit {
   string applicationId = 1;
   int32 shuffleId = 2;
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 9697aabbc..49bb329cc 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -208,6 +208,34 @@ object ControlMessages extends Logging {
         .build()
   }
 
+  object ReviveBatch {
+    def apply(
+        appId: String,
+        shuffleId: Int,
+        mapId: Int,
+        attemptId: Int,
+        partitionId: Array[Int],
+        epoch: Array[Int],
+        oldPartition: Array[PartitionLocation],
+        cause: Array[StatusCode]): PbReviveBatch = {
+      val builder = PbReviveBatch.newBuilder()
+      builder
+        .setApplicationId(appId)
+        .setShuffleId(shuffleId)
+        .setMapId(mapId)
+        .setAttemptId(attemptId)
+
+      0 until partitionId.length foreach (idx => {
+        builder.addPartitionId(partitionId(idx))
+        builder.addEpoch(epoch(idx))
+        
builder.addOldPartition(PbSerDeUtils.toPbPartitionLocation(oldPartition(idx)))
+        builder.addStatus(cause(idx).getValue)
+      })
+
+      builder.build()
+    }
+  }
+
   object PartitionSplit {
     def apply(
         appId: String,
@@ -237,6 +265,21 @@ object ControlMessages extends Logging {
     }
   }
 
+  object ChangeLocationsResponse {
+    def apply(locMap: util.Map[Int, (StatusCode, PartitionLocation)]): 
PbChangeLocationsResponse = {
+      val builder = PbChangeLocationsResponse.newBuilder()
+      locMap.asScala.foreach(entry => {
+        val id = entry._1
+        val statusCode = entry._2._1
+        val loc = entry._2._2
+        builder.addIds(id)
+        builder.addStatus(statusCode.getValue)
+        builder.addLocation(PbSerDeUtils.toPbPartitionLocation(loc))
+      })
+      builder.build()
+    }
+  }
+
   case class MapperEnd(
       applicationId: String,
       shuffleId: Int,
@@ -518,9 +561,15 @@ object ControlMessages extends Logging {
     case pb: PbRevive =>
       new TransportMessage(MessageType.REVIVE, pb.toByteArray)
 
+    case pb: PbReviveBatch =>
+      new TransportMessage(MessageType.REVIVE_BATCH, pb.toByteArray)
+
     case pb: PbChangeLocationResponse =>
       new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, 
pb.toByteArray)
 
+    case pb: PbChangeLocationsResponse =>
+      new TransportMessage(MessageType.CHANGE_LOCATIONS_RESPONSE, 
pb.toByteArray)
+
     case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers) =>
       val payload = PbMapperEnd.newBuilder()
         .setApplicationId(applicationId)
@@ -865,9 +914,15 @@ object ControlMessages extends Logging {
       case REVIVE =>
         PbRevive.parseFrom(message.getPayload)
 
+      case REVIVE_BATCH =>
+        PbReviveBatch.parseFrom(message.getPayload)
+
       case CHANGE_LOCATION_RESPONSE =>
         PbChangeLocationResponse.parseFrom(message.getPayload)
 
+      case CHANGE_LOCATIONS_RESPONSE =>
+        PbChangeLocationsResponse.parseFrom(message.getPayload)
+
       case MAPPER_END =>
         val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload)
         MapperEnd(
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index e90368799..22ba4a667 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -210,7 +210,7 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
     // This should before return exception to make current push data can 
revive and retry.
     if (shutdown.get()) {
       logInfo(s"Push data return HARD_SPLIT for shuffle $shuffleKey since 
worker shutdown.")
-      
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+      
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.WORKER_SHUTDOWN.getValue)))
       return
     }
 
@@ -395,7 +395,8 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
     // During worker shutdown, worker will return HARD_SPLIT for all existed 
partition.
     // This should before return exception to make current push data can 
revive and retry.
     if (shutdown.get()) {
-      
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+      logInfo("shutting down, return WORKER_SHUTDOWN")
+      
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.WORKER_SHUTDOWN.getValue)))
       return
     }
 

Reply via email to