This is an automated email from the ASF dual-hosted git repository. zhouky pushed a commit to branch test in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
commit 540912ed512321f852fd75e14887c36fa2081a6a Author: zky.zhoukeyong <[email protected]> AuthorDate: Fri Jun 9 19:43:39 2023 +0800 refine revive --- .../apache/celeborn/client/ShuffleClientImpl.java | 158 +++++++++++++-------- .../celeborn/client/ChangePartitionManager.scala | 7 +- .../apache/celeborn/client/LifecycleManager.scala | 79 +++++++---- .../client/RequestLocationCallContext.scala | 35 ++++- .../common/protocol/PartitionLocation.java | 2 + common/src/main/proto/TransportMessages.proto | 19 +++ .../common/protocol/message/ControlMessages.scala | 55 +++++++ .../service/deploy/worker/PushDataHandler.scala | 5 +- 8 files changed, 266 insertions(+), 94 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index e32dd1d8d..f2f3de2b3 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -175,7 +175,7 @@ public class ShuffleClientImpl extends ShuffleClient { pushState.removeBatch(batchId); } else { PartitionLocation newLoc = reducePartitionMap.get(shuffleId).get(partitionId); - logger.info("Revive success, new location for reduce {} is {}.", partitionId, newLoc); + logger.debug("Revive success, new location for reduce {} is {}.", partitionId, newLoc); try { TransportClient client = dataClientFactory.createClient(newLoc.getHost(), newLoc.getPushPort(), partitionId); @@ -209,27 +209,34 @@ public class ShuffleClientImpl extends ShuffleClient { StatusCode cause, Integer oldGroupedBatchId) { HashMap<String, DataBatches> newDataBatchesMap = new HashMap<>(); - for (DataBatches.DataBatch batch : batches) { - int partitionId = batch.loc.getId(); - if (!revive( - applicationId, - shuffleId, - mapId, - attemptId, - partitionId, - batch.loc.getEpoch(), - batch.loc, - cause)) { - pushState.exception.compareAndSet( - null, - new IOException("Revive Failed in retry push merged data for location: " + batch.loc)); - return; - } else if (mapperEnded(shuffleId, mapId, attemptId)) { - logger.debug( - "Retrying push data, but the mapper(map {} attempt {}) has ended.", mapId, attemptId); - } else { - PartitionLocation newLoc = reducePartitionMap.get(shuffleId).get(partitionId); - logger.info("Revive success, new location for reduce {} is {}.", partitionId, newLoc); + int[] ids = new int[batches.size()]; + int[] epoches = new int[batches.size()]; + PartitionLocation[] locs = new PartitionLocation[batches.size()]; + StatusCode[] causes = new StatusCode[batches.size()]; + + for (int i = 0; i < batches.size(); i++) { + DataBatches.DataBatch batch = batches.get(i); + ids[i] = batch.loc.getId(); + epoches[i] = batch.loc.getEpoch(); + locs[i] = batch.loc; + causes[i] = cause; + } + + boolean reviveSuccess = + reviveBatch(applicationId, shuffleId, mapId, attemptId, ids, epoches, locs, causes); + if (!reviveSuccess) { + pushState.exception.compareAndSet( + null, + new IOException( + "Revive Failed in retry push merged data for location: " + locs[0].getHost())); + return; + } else if (mapperEnded(shuffleId, mapId, attemptId)) { + logger.debug( + "Retrying push data, but the mapper(map {} attempt {}) has ended.", mapId, attemptId); + } else { + for (int i = 0; i < batches.size(); i++) { + DataBatches.DataBatch batch = batches.get(i); + PartitionLocation newLoc = reducePartitionMap.get(shuffleId).get(batch.loc.getId()); DataBatches newDataBatches = newDataBatchesMap.computeIfAbsent(genAddressPair(newLoc), (s) -> new DataBatches()); newDataBatches.addDataBatch(newLoc, batch.batchId, batch.body); @@ -395,20 +402,22 @@ public class ShuffleClientImpl extends ShuffleClient { PartitionLocation currentLocation = map.get(partitionId); if (currentLocation != null && currentLocation.getEpoch() > epoch) { return true; + } else { + return false; } - - long sleepTimeMs = RND.nextInt(50); - if (sleepTimeMs > 30) { - try { - TimeUnit.MILLISECONDS.sleep(sleepTimeMs); - } catch (InterruptedException e) { - logger.warn("Wait revived location interrupted", e); - Thread.currentThread().interrupt(); - } - } - - currentLocation = map.get(partitionId); - return currentLocation != null && currentLocation.getEpoch() > epoch; +// +// long sleepTimeMs = RND.nextInt(50); +// if (sleepTimeMs > 30) { +// try { +// TimeUnit.MILLISECONDS.sleep(sleepTimeMs); +// } catch (InterruptedException e) { +// logger.warn("Wait revived location interrupted", e); +// Thread.currentThread().interrupt(); +// } +// } + +// currentLocation = map.get(partitionId); +// return currentLocation != null && currentLocation.getEpoch() > epoch; } private boolean revive( @@ -431,6 +440,28 @@ public class ShuffleClientImpl extends ShuffleClient { epoch); return true; } + + return reviveBatch( + applicationId, + shuffleId, + mapId, + attemptId, + new int[] {partitionId}, + new int[] {epoch}, + new PartitionLocation[] {oldLocation}, + new StatusCode[] {cause}); + } + + private boolean reviveBatch( + String applicationId, + int shuffleId, + int mapId, + int attemptId, + int[] partitionIds, + int[] epoches, + PartitionLocation[] oldLocations, + StatusCode[] causes) { + ConcurrentHashMap<Integer, PartitionLocation> map = reducePartitionMap.get(shuffleId); String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); if (mapperEnded(shuffleId, mapId, attemptId)) { logger.debug( @@ -442,36 +473,45 @@ public class ShuffleClientImpl extends ShuffleClient { } try { - PbChangeLocationResponse response = + PbChangeLocationsResponse response = driverRssMetaService.askSync( - Revive$.MODULE$.apply( + ReviveBatch$.MODULE$.apply( applicationId, shuffleId, mapId, attemptId, - partitionId, - epoch, - oldLocation, - cause), + partitionIds, + epoches, + oldLocations, + causes), conf.requestPartitionLocationRpcAskTimeout(), - ClassTag$.MODULE$.apply(PbChangeLocationResponse.class)); - // per partitionKey only serve single PartitionLocation in Client Cache. - StatusCode respStatus = Utils.toStatusCode(response.getStatus()); - if (StatusCode.SUCCESS.equals(respStatus)) { - map.put(partitionId, PbSerDeUtils.fromPbPartitionLocation(response.getLocation())); - return true; - } else if (StatusCode.MAP_ENDED.equals(respStatus)) { - mapperEndMap.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()).add(mapKey); - return true; - } else { - return false; + ClassTag$.MODULE$.apply(PbChangeLocationsResponse.class)); + boolean allSuccess = true; + for (int i = 0; i < response.getIdsCount(); i++) { + if (StatusCode.SUCCESS.equals(Utils.toStatusCode(response.getStatus(i)))) { + map.put( + response.getIds(i), PbSerDeUtils.fromPbPartitionLocation(response.getLocation(i))); + } else { + logger.info("StatusCode not SUCCESS! {}", response.getStatus(i)); + allSuccess = false; + } } + if (!allSuccess) { + StatusCode firstStatus = Utils.toStatusCode(response.getStatus(0)); + if (StatusCode.MAP_ENDED.equals(firstStatus)) { + mapperEndMap + .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()) + .add(mapKey); + return true; + } + } + return allSuccess; } catch (Exception e) { logger.error( "Exception raised while reviving for shuffle {} reduce {} epoch {}.", shuffleId, - partitionId, - epoch, + partitionIds, + epoches, e); return false; } @@ -645,7 +685,8 @@ public class ShuffleClientImpl extends ShuffleClient { nextBatchId); splitPartition(shuffleId, partitionId, applicationId, loc); callback.onSuccess(response); - } else if (reason == StatusCode.HARD_SPLIT.getValue()) { + } else if (reason == StatusCode.HARD_SPLIT.getValue() || + reason == StatusCode.WORKER_SHUTDOWN.getValue()) { logger.debug( "Push data split for map {} attempt {} batch {}.", mapId, @@ -663,7 +704,7 @@ public class ShuffleClientImpl extends ShuffleClient { loc, this, pushState, - StatusCode.HARD_SPLIT)); + Utils.toStatusCode(reason))); } else { response.rewind(); callback.onSuccess(response); @@ -944,7 +985,8 @@ public class ShuffleClientImpl extends ShuffleClient { public void onSuccess(ByteBuffer response) { if (response.remaining() > 0) { byte reason = response.get(); - if (reason == StatusCode.HARD_SPLIT.getValue()) { + if (reason == StatusCode.HARD_SPLIT.getValue() || + reason == StatusCode.WORKER_SHUTDOWN.getValue()) { logger.info( "Push merged data return hard split for map " + mapId @@ -962,7 +1004,7 @@ public class ShuffleClientImpl extends ShuffleClient { mapId, attemptId, batches, - StatusCode.HARD_SPLIT, + Utils.toStatusCode(reason), groupedBatchId)); } else { // Should not happen in current architecture. diff --git a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala index ea1de6af4..62832cc85 100644 --- a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala @@ -161,7 +161,7 @@ class ChangePartitionManager( // If new slot for the partition has been allocated, reply and return. // Else register and allocate for it. getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc => - context.reply(StatusCode.SUCCESS, Some(latestLoc)) + context.reply(partitionId, StatusCode.SUCCESS, Some(latestLoc)) logDebug(s"New partition found, old partition $partitionId-$oldEpoch return it." + s" shuffleId: $shuffleId $latestLoc") return @@ -204,6 +204,7 @@ class ChangePartitionManager( // Blacklist all failed workers if (changePartitions.exists(_.causes.isDefined)) { changePartitions.filter(_.causes.isDefined).foreach { changePartition => + logInfo("cause is " + changePartition.causes.get) lifecycleManager.blacklistPartition( shuffleId, changePartition.oldPartition, @@ -224,6 +225,7 @@ class ChangePartitionManager( } }.foreach { case (newLocation, requests) => requests.map(_.asScala.toList.foreach(_.context.reply( + newLocation.getId, StatusCode.SUCCESS, Option(newLocation)))) } @@ -239,7 +241,8 @@ class ChangePartitionManager( Option(requestsMap.remove(changePartition.partitionId)) } }.foreach { requests => - requests.map(_.asScala.toList.foreach(_.context.reply(status, None))) + requests.map(_.asScala.toList.foreach(req => + req.context.reply(req.partitionId, status, None))) } } diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 1f766fed6..5b1425316 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -22,15 +22,12 @@ import java.util import java.util.{function, List => JList} import java.util.concurrent.{Callable, ConcurrentHashMap, ScheduledExecutorService, ScheduledFuture, TimeUnit} import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder} - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random - import com.google.common.annotations.VisibleForTesting import com.google.common.cache.{Cache, CacheBuilder} import org.roaringbitmap.RoaringBitmap - import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.haclient.RssHARetryClient import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier} @@ -439,10 +436,33 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin shuffleId, mapId, attemptId, - partitionId, - epoch, - oldPartition, - cause) + Array(partitionId), + Array(epoch), + Array(oldPartition), + Array(cause)) + + case pb: PbReviveBatch => + val applicationId = pb.getApplicationId + val shuffleId = pb.getShuffleId + val mapId = pb.getMapId + val attemptId = pb.getAttemptId + val ids = pb.getPartitionIdList.asScala.map(_.toInt).toArray + val epoches = pb.getEpochList.asScala.map(_.toInt).toArray + val oldPartitions = pb.getOldPartitionList.asScala.map(PbSerDeUtils.fromPbPartitionLocation(_)).toArray + val causes = pb.getStatusList.asScala.map(Utils.toStatusCode(_)).toArray + logTrace(s"Received ReviveBatch request, " + + s"$applicationId, $shuffleId, $mapId, $attemptId, ,${ids.mkString(",")}," + + s" ${epoches.mkString(",")}, ${oldPartitions.mkString(",")}, ${causes.mkString(",")}") + handleRevive( + context, + applicationId, + shuffleId, + mapId, + attemptId, + ids, + epoches, + oldPartitions, + causes) case pb: PbPartitionSplit => val applicationId = pb.getApplicationId @@ -684,12 +704,12 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin cause: StatusCode): Unit = { // only blacklist if cause is PushDataFailMain val failedWorker = new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]() - if (cause == StatusCode.PUSH_DATA_FAIL_MASTER && oldPartition != null) { + if ((cause == StatusCode.PUSH_DATA_FAIL_MASTER || cause == StatusCode.WORKER_SHUTDOWN) && oldPartition != null) { val tmpWorker = oldPartition.getWorker val worker = workerSnapshots(shuffleId).keySet().asScala .find(_.equals(tmpWorker)) if (worker.isDefined) { - failedWorker.put(worker.get, (StatusCode.PUSH_DATA_FAIL_MASTER, System.currentTimeMillis())) + failedWorker.put(worker.get, (cause, System.currentTimeMillis())) } } if (!failedWorker.isEmpty) { @@ -703,14 +723,15 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin shuffleId: Int, mapId: Int, attemptId: Int, - partitionId: Int, - oldEpoch: Int, - oldPartition: PartitionLocation, - cause: StatusCode): Unit = { + partitionIds: Array[Int], + oldEpochs: Array[Int], + oldPartitions: Array[PartitionLocation], + causes: Array[StatusCode]): Unit = { + val contextWrapper = ChangeLocationsCallContext(context, partitionIds.length) // If shuffle not registered, reply ShuffleNotRegistered and return if (!registeredShuffle.contains(shuffleId)) { logError(s"[handleRevive] shuffle $shuffleId not registered!") - context.reply(ChangeLocationResponse(StatusCode.SHUFFLE_NOT_REGISTERED, None)) + contextWrapper.reply(-1, StatusCode.SHUFFLE_NOT_REGISTERED, None) return } @@ -719,21 +740,23 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin && shuffleMapperAttempts.get(shuffleId)(mapId) != -1) { logWarning(s"[handleRevive] Mapper ended, mapId $mapId, current attemptId $attemptId, " + s"ended attemptId ${shuffleMapperAttempts.get(shuffleId)(mapId)}, shuffleId $shuffleId.") - context.reply(ChangeLocationResponse(StatusCode.MAP_ENDED, None)) + contextWrapper.reply(-1, StatusCode.MAP_ENDED, None) return } - logWarning(s"Do Revive for shuffle ${Utils.makeShuffleKey(applicationId, shuffleId)}, " + - s"oldPartition: $oldPartition, cause: $cause") - - changePartitionManager.handleRequestPartitionLocation( - ChangeLocationCallContext(context), - applicationId, - shuffleId, - partitionId, - oldEpoch, - oldPartition, - Some(cause)) + logDebug(s"Do Revive for shuffle ${Utils.makeShuffleKey(applicationId, shuffleId)}, " + + s"oldPartition: ${oldPartitions.mkString(",")}, cause: ${causes.mkString(",")}") + + 0 until partitionIds.length foreach (idx => { + changePartitionManager.handleRequestPartitionLocation( + contextWrapper, + applicationId, + shuffleId, + partitionIds(idx), + oldEpochs(idx), + oldPartitions(idx), + Some(causes(idx))) + }) } private def handleMapperEnd( @@ -1252,7 +1275,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin } if (!destroyResource.isEmpty) { destroySlotsWithRetry(applicationId, shuffleId, destroyResource) - logInfo(s"Destroyed peer partitions for reserve buffer failed workers " + + logDebug(s"Destroyed peer partitions for reserve buffer failed workers " + s"${Utils.makeShuffleKey(applicationId, shuffleId)}, $destroyResource") val workerIds = new util.ArrayList[String]() @@ -1264,7 +1287,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin } val msg = ReleaseSlots(applicationId, shuffleId, workerIds, workerSlotsPerDisk) requestReleaseSlots(rssHARetryClient, msg) - logInfo(s"Released slots for reserve buffer failed workers " + + logDebug(s"Released slots for reserve buffer failed workers " + s"${workerIds.asScala.mkString(",")}" + s"${slots.asScala.mkString(",")}" + s"${Utils.makeShuffleKey(applicationId, shuffleId)}, ") } diff --git a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala index 91cbf6df6..0aa43f7ca 100644 --- a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala +++ b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala @@ -17,23 +17,50 @@ package org.apache.celeborn.client +import java.util.concurrent.ConcurrentHashMap + import org.apache.celeborn.common.protocol.PartitionLocation -import org.apache.celeborn.common.protocol.message.ControlMessages.{ChangeLocationResponse, RegisterShuffleResponse} +import org.apache.celeborn.common.protocol.message.ControlMessages.{ChangeLocationResponse, ChangeLocationsResponse, RegisterShuffleResponse} import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc.RpcCallContext trait RequestLocationCallContext { - def reply(status: StatusCode, partitionLocationOpt: Option[PartitionLocation]): Unit + def reply(id: Int, status: StatusCode, partitionLocationOpt: Option[PartitionLocation]): Unit } case class ChangeLocationCallContext(context: RpcCallContext) extends RequestLocationCallContext { - override def reply(status: StatusCode, partitionLocationOpt: Option[PartitionLocation]): Unit = { + override def reply( + id: Int, + status: StatusCode, + partitionLocationOpt: Option[PartitionLocation]): Unit = { context.reply(ChangeLocationResponse(status, partitionLocationOpt)) } } +case class ChangeLocationsCallContext(context: RpcCallContext, count: Int) + extends RequestLocationCallContext { + val locMap = new ConcurrentHashMap[Int, (StatusCode, PartitionLocation)](count) + override def reply( + id: Int, + status: StatusCode, + partitionLocationOpt: Option[PartitionLocation]): Unit = { + locMap.putIfAbsent(id, (status, partitionLocationOpt.getOrElse(new PartitionLocation()))) + if (locMap.size() == count) { + locMap.synchronized { + if (locMap.size() == count || id == -1) { + context.reply(ChangeLocationsResponse(locMap)) + locMap.clear() + } + } + } + } +} + case class ApplyNewLocationCallContext(context: RpcCallContext) extends RequestLocationCallContext { - override def reply(status: StatusCode, partitionLocationOpt: Option[PartitionLocation]): Unit = { + override def reply( + id: Int, + status: StatusCode, + partitionLocationOpt: Option[PartitionLocation]): Unit = { partitionLocationOpt match { case Some(partitionLocation) => context.reply(RegisterShuffleResponse(status, Array(partitionLocation))) diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java index dccf87833..58be2afbc 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java @@ -61,6 +61,8 @@ public class PartitionLocation implements Serializable { private StorageInfo storageInfo; private RoaringBitmap mapIdBitMap; + public PartitionLocation() {} + public PartitionLocation(PartitionLocation loc) { this.id = loc.id; this.epoch = loc.epoch; diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 0c1377dec..9e80b2e40 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -66,6 +66,8 @@ enum MessageType { STAGE_END = 45; STAGE_END_RESPONSE = 46; PARTITION_SPLIT = 47; + REVIVE_BATCH = 48; + CHANGE_LOCATIONS_RESPONSE = 49; } message PbStorageInfo { @@ -214,11 +216,28 @@ message PbRevive { int32 status = 8; } +message PbReviveBatch { + string applicationId = 1; + int32 shuffleId = 2; + int32 mapId = 3; + int32 attemptId = 4; + repeated int32 partitionId = 5; + repeated int32 epoch = 6; + repeated PbPartitionLocation oldPartition = 7; + repeated int32 status = 8; +} + message PbChangeLocationResponse { int32 status = 1; PbPartitionLocation location = 2; } +message PbChangeLocationsResponse { + repeated int32 ids = 1; + repeated int32 status = 2; + repeated PbPartitionLocation location = 3; +} + message PbPartitionSplit { string applicationId = 1; int32 shuffleId = 2; diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 9697aabbc..49bb329cc 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -208,6 +208,34 @@ object ControlMessages extends Logging { .build() } + object ReviveBatch { + def apply( + appId: String, + shuffleId: Int, + mapId: Int, + attemptId: Int, + partitionId: Array[Int], + epoch: Array[Int], + oldPartition: Array[PartitionLocation], + cause: Array[StatusCode]): PbReviveBatch = { + val builder = PbReviveBatch.newBuilder() + builder + .setApplicationId(appId) + .setShuffleId(shuffleId) + .setMapId(mapId) + .setAttemptId(attemptId) + + 0 until partitionId.length foreach (idx => { + builder.addPartitionId(partitionId(idx)) + builder.addEpoch(epoch(idx)) + builder.addOldPartition(PbSerDeUtils.toPbPartitionLocation(oldPartition(idx))) + builder.addStatus(cause(idx).getValue) + }) + + builder.build() + } + } + object PartitionSplit { def apply( appId: String, @@ -237,6 +265,21 @@ object ControlMessages extends Logging { } } + object ChangeLocationsResponse { + def apply(locMap: util.Map[Int, (StatusCode, PartitionLocation)]): PbChangeLocationsResponse = { + val builder = PbChangeLocationsResponse.newBuilder() + locMap.asScala.foreach(entry => { + val id = entry._1 + val statusCode = entry._2._1 + val loc = entry._2._2 + builder.addIds(id) + builder.addStatus(statusCode.getValue) + builder.addLocation(PbSerDeUtils.toPbPartitionLocation(loc)) + }) + builder.build() + } + } + case class MapperEnd( applicationId: String, shuffleId: Int, @@ -518,9 +561,15 @@ object ControlMessages extends Logging { case pb: PbRevive => new TransportMessage(MessageType.REVIVE, pb.toByteArray) + case pb: PbReviveBatch => + new TransportMessage(MessageType.REVIVE_BATCH, pb.toByteArray) + case pb: PbChangeLocationResponse => new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray) + case pb: PbChangeLocationsResponse => + new TransportMessage(MessageType.CHANGE_LOCATIONS_RESPONSE, pb.toByteArray) + case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers) => val payload = PbMapperEnd.newBuilder() .setApplicationId(applicationId) @@ -865,9 +914,15 @@ object ControlMessages extends Logging { case REVIVE => PbRevive.parseFrom(message.getPayload) + case REVIVE_BATCH => + PbReviveBatch.parseFrom(message.getPayload) + case CHANGE_LOCATION_RESPONSE => PbChangeLocationResponse.parseFrom(message.getPayload) + case CHANGE_LOCATIONS_RESPONSE => + PbChangeLocationsResponse.parseFrom(message.getPayload) + case MAPPER_END => val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload) MapperEnd( diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala index e90368799..22ba4a667 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala @@ -210,7 +210,7 @@ class PushDataHandler extends BaseMessageHandler with Logging { // This should before return exception to make current push data can revive and retry. if (shutdown.get()) { logInfo(s"Push data return HARD_SPLIT for shuffle $shuffleKey since worker shutdown.") - callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue))) + callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.WORKER_SHUTDOWN.getValue))) return } @@ -395,7 +395,8 @@ class PushDataHandler extends BaseMessageHandler with Logging { // During worker shutdown, worker will return HARD_SPLIT for all existed partition. // This should before return exception to make current push data can revive and retry. if (shutdown.get()) { - callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue))) + logInfo("shutting down, return WORKER_SHUTDOWN") + callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.WORKER_SHUTDOWN.getValue))) return }
