This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 6629be858 [CELEBORN-1574] Speed up unregister shuffle by batch processing
6629be858 is described below

commit 6629be858bf5eecdab7930aa361f67743fbeb2f7
Author: szt <[email protected]>
AuthorDate: Tue Oct 8 14:03:41 2024 +0800

    [CELEBORN-1574] Speed up unregister shuffle by batch processing
    
    ### What changes were proposed in this pull request?
    In order to speed up resource release, this PR unregisters shuffles in batches.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests and local cluster testing.
    
    Closes #2701 from zaynt4606/batchUnregister.
    
    Lead-authored-by: szt <[email protected]>
    Co-authored-by: Zaynt <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../apache/celeborn/client/LifecycleManager.scala  |  42 ++++++-
 common/src/main/proto/TransportMessages.proto      |  13 +++
 .../org/apache/celeborn/common/CelebornConf.scala  |  10 ++
 .../common/protocol/message/ControlMessages.scala  |  35 ++++++
 docs/configuration/client.md                       |   1 +
 .../master/clustermeta/AbstractMetaManager.java    |   4 +
 .../master/clustermeta/IMetadataHandler.java       |   2 +
 .../clustermeta/SingleMasterMetaManager.java       |   5 +
 .../master/clustermeta/ha/HAMasterMetaManager.java |  18 +++
 .../deploy/master/clustermeta/ha/MetaHandler.java  |   7 ++
 master/src/main/proto/Resource.proto               |   6 +
 .../celeborn/service/deploy/master/Master.scala    |  22 ++++
 .../clustermeta/DefaultMetaSystemSuiteJ.java       |  82 ++++++++++++++
 .../ha/RatisMasterStatusSystemSuiteJ.java          |  96 ++++++++++++++++
 .../LifecycleManagerUnregisterShuffleSuite.scala   | 121 +++++++++++++++++++++
 15 files changed, 460 insertions(+), 4 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index f1324f9cb..60721c160 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -29,6 +29,7 @@ import java.util.function.Consumer
 import scala.collection.JavaConverters._
 import scala.collection.generic.CanBuildFrom
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 import scala.util.Random
@@ -103,6 +104,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
   private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
   private val rpcMaxRetires = conf.clientRpcMaxRetries
 
+  private val batchRemoveExpiredShufflesEnabled = 
conf.batchHandleRemoveExpiredShufflesEnabled
+
   private val excludedWorkersFilter = 
conf.registerShuffleFilterExcludedWorkerEnabled
 
   private val registerShuffleResponseRpcCache: Cache[Int, ByteBuffer] = 
CacheBuilder.newBuilder()
@@ -1579,6 +1582,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
 
   private def removeExpiredShuffle(): Unit = {
     val currentTime = System.currentTimeMillis()
+    val batchRemoveShuffleIds = new ArrayBuffer[Integer]
     unregisterShuffleTime.keys().asScala.foreach { shuffleId =>
       if (unregisterShuffleTime.get(shuffleId) < currentTime - 
shuffleExpiredCheckIntervalMs) {
         logInfo(s"Clear shuffle $shuffleId.")
@@ -1589,10 +1593,26 @@ class LifecycleManager(val appUniqueId: String, val 
conf: CelebornConf) extends
         latestPartitionLocation.remove(shuffleId)
         commitManager.removeExpiredShuffle(shuffleId)
         changePartitionManager.removeExpiredShuffle(shuffleId)
-        val unregisterShuffleResponse = requestMasterUnregisterShuffle(
-          UnregisterShuffle(appUniqueId, shuffleId, 
MasterClient.genRequestId()))
-        // if unregister shuffle not success, wait next turn
-        if (StatusCode.SUCCESS == 
Utils.toStatusCode(unregisterShuffleResponse.getStatus)) {
+        if (!batchRemoveExpiredShufflesEnabled) {
+          val unregisterShuffleResponse = requestMasterUnregisterShuffle(
+            UnregisterShuffle(appUniqueId, shuffleId, 
MasterClient.genRequestId()))
+          // if unregister shuffle not success, wait next turn
+          if (StatusCode.SUCCESS == 
Utils.toStatusCode(unregisterShuffleResponse.getStatus)) {
+            unregisterShuffleTime.remove(shuffleId)
+          }
+        } else {
+          batchRemoveShuffleIds += shuffleId
+        }
+      }
+    }
+    if (batchRemoveShuffleIds.nonEmpty) {
+      val unregisterShuffleResponse = batchRequestMasterUnregisterShuffles(
+        BatchUnregisterShuffles(
+          appUniqueId,
+          batchRemoveShuffleIds.asJava,
+          MasterClient.genRequestId()))
+      if (StatusCode.SUCCESS == 
Utils.toStatusCode(unregisterShuffleResponse.getStatus)) {
+        batchRemoveShuffleIds.foreach { shuffleId: Integer =>
           unregisterShuffleTime.remove(shuffleId)
         }
       }
@@ -1671,6 +1691,20 @@ class LifecycleManager(val appUniqueId: String, val 
conf: CelebornConf) extends
     }
   }
 
+  private def batchRequestMasterUnregisterShuffles(message: 
PbBatchUnregisterShuffles)
+      : PbBatchUnregisterShuffleResponse = {
+    try {
+      logInfo(s"AskSync BatchUnregisterShuffle for 
${message.getShuffleIdsList}")
+      masterClient.askSync[PbBatchUnregisterShuffleResponse](
+        message,
+        classOf[PbBatchUnregisterShuffleResponse])
+    } catch {
+      case e: Exception =>
+        logError(s"AskSync BatchUnregisterShuffle for 
${message.getShuffleIdsList} failed.", e)
+        BatchUnregisterShuffleResponse(StatusCode.REQUEST_FAILED)
+    }
+  }
+
   def checkQuota(): CheckQuotaResponse = {
     try {
       masterClient.askSync[CheckQuotaResponse](
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 9bde109f0..4dc4f1489 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -107,6 +107,8 @@ enum MessageType {
   REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE = 84;
   SEGMENT_START = 85;
   NOTIFY_REQUIRED_SEGMENT = 86;
+  BATCH_UNREGISTER_SHUFFLES = 87;
+  BATCH_UNREGISTER_SHUFFLE_RESPONSE= 88;
 }
 
 enum StreamType {
@@ -406,10 +408,21 @@ message PbUnregisterShuffle {
   string requestId = 3;
 }
 
+message PbBatchUnregisterShuffles {
+  string appId = 1;
+  string requestId = 2;
+  repeated int32 shuffleIds = 3;
+}
+
 message PbUnregisterShuffleResponse {
   int32 status = 1;
 }
 
+message PbBatchUnregisterShuffleResponse {
+  int32 status = 1;
+  repeated int32 shuffleIds = 2;
+}
+
 message PbApplicationLost {
   string appId = 1;
   string requestId = 2;
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 81ae7dc2b..290ce60f9 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1042,6 +1042,8 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
   def batchHandleChangePartitionNumThreads: Int = 
get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS)
   def batchHandleChangePartitionRequestInterval: Long =
     get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_INTERVAL)
+  def batchHandleRemoveExpiredShufflesEnabled: Boolean =
+    get(CLIENT_BATCH_REMOVE_EXPIRED_SHUFFLE_ENABLED)
   def batchHandleCommitPartitionEnabled: Boolean = 
get(CLIENT_BATCH_HANDLE_COMMIT_PARTITION_ENABLED)
   def batchHandleCommitPartitionNumThreads: Int = 
get(CLIENT_BATCH_HANDLE_COMMIT_PARTITION_THREADS)
   def batchHandleCommitPartitionRequestInterval: Long =
@@ -4592,6 +4594,14 @@ object CelebornConf extends Logging {
       .booleanConf
       .createWithDefault(true)
 
+  val CLIENT_BATCH_REMOVE_EXPIRED_SHUFFLE_ENABLED: ConfigEntry[Boolean] =
+    
buildConf("celeborn.client.shuffle.batchHandleRemoveExpiredShuffles.enabled")
+      .categories("client")
+      .version("0.6.0")
+      .doc("Whether to batch remove expired shuffles. This is an optimization 
switch on removing expired shuffles.")
+      .booleanConf
+      .createWithDefault(false)
+
   val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_BUCKETS: ConfigEntry[Int] =
     
buildConf("celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets")
       .categories("client")
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index e68212baf..5c345237c 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -351,6 +351,18 @@ object ControlMessages extends Logging {
         .build()
   }
 
+  object BatchUnregisterShuffles {
+    def apply(
+        appId: String,
+        shuffleIds: util.List[Integer],
+        requestId: String): PbBatchUnregisterShuffles =
+      PbBatchUnregisterShuffles.newBuilder()
+        .setAppId(appId)
+        .addAllShuffleIds(shuffleIds)
+        .setRequestId(requestId)
+        .build()
+  }
+
   object UnregisterShuffleResponse {
     def apply(status: StatusCode): PbUnregisterShuffleResponse =
       PbUnregisterShuffleResponse.newBuilder()
@@ -358,6 +370,17 @@ object ControlMessages extends Logging {
         .build()
   }
 
+  object BatchUnregisterShuffleResponse {
+    def apply(
+        status: StatusCode,
+        shuffleIds: util.List[Integer] = Collections.emptyList())
+        : PbBatchUnregisterShuffleResponse =
+      PbBatchUnregisterShuffleResponse.newBuilder()
+        .setStatus(status.getValue)
+        .addAllShuffleIds(shuffleIds)
+        .build()
+  }
+
   case class ApplicationLost(
       appId: String,
       override var requestId: String = ZERO_UUID) extends MasterRequestMessage
@@ -731,9 +754,15 @@ object ControlMessages extends Logging {
     case pb: PbUnregisterShuffle =>
       new TransportMessage(MessageType.UNREGISTER_SHUFFLE, pb.toByteArray)
 
+    case pb: PbBatchUnregisterShuffles =>
+      new TransportMessage(MessageType.BATCH_UNREGISTER_SHUFFLES, 
pb.toByteArray)
+
     case pb: PbUnregisterShuffleResponse =>
       new TransportMessage(MessageType.UNREGISTER_SHUFFLE_RESPONSE, 
pb.toByteArray)
 
+    case pb: PbBatchUnregisterShuffleResponse =>
+      new TransportMessage(MessageType.BATCH_UNREGISTER_SHUFFLE_RESPONSE, 
pb.toByteArray)
+
     case ApplicationLost(appId, requestId) =>
       val payload = PbApplicationLost.newBuilder()
         .setAppId(appId).setRequestId(requestId)
@@ -1119,9 +1148,15 @@ object ControlMessages extends Logging {
       case UNREGISTER_SHUFFLE_VALUE =>
         PbUnregisterShuffle.parseFrom(message.getPayload)
 
+      case BATCH_UNREGISTER_SHUFFLES_VALUE =>
+        PbBatchUnregisterShuffles.parseFrom(message.getPayload)
+
       case UNREGISTER_SHUFFLE_RESPONSE_VALUE =>
         PbUnregisterShuffleResponse.parseFrom(message.getPayload)
 
+      case BATCH_UNREGISTER_SHUFFLE_RESPONSE_VALUE =>
+        PbBatchUnregisterShuffleResponse.parseFrom(message.getPayload)
+
       case APPLICATION_LOST_VALUE =>
         val pbApplicationLost = PbApplicationLost.parseFrom(message.getPayload)
         ApplicationLost(pbApplicationLost.getAppId, 
pbApplicationLost.getRequestId)
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index bcae9e40f..7e7068e17 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -90,6 +90,7 @@ license: |
 | celeborn.client.shuffle.batchHandleCommitPartition.threads | 8 | false | 
Threads number for LifecycleManager to handle commit partition request in 
batch. | 0.3.0 | celeborn.shuffle.batchHandleCommitPartition.threads | 
 | celeborn.client.shuffle.batchHandleReleasePartition.interval | 5s | false | 
Interval for LifecycleManager to schedule handling release partition requests 
in batch. | 0.3.0 |  | 
 | celeborn.client.shuffle.batchHandleReleasePartition.threads | 8 | false | 
Threads number for LifecycleManager to handle release partition request in 
batch. | 0.3.0 |  | 
+| celeborn.client.shuffle.batchHandleRemoveExpiredShuffles.enabled | false | 
false | Whether to batch remove expired shuffles. This is an optimization 
switch on removing expired shuffles. | 0.6.0 |  | 
 | celeborn.client.shuffle.compression.codec | LZ4 | false | The codec used to 
compress shuffle data. By default, Celeborn provides three codecs: `lz4`, 
`zstd`, `none`. `none` means that shuffle compression is disabled. Since Flink 
version 1.17, zstd is supported for Flink shuffle client. | 0.3.0 | 
celeborn.shuffle.compression.codec,remote-shuffle.job.compression.codec | 
 | celeborn.client.shuffle.compression.zstd.level | 1 | false | Compression 
level for Zstd compression codec, its value should be an integer between -5 and 
22. Increasing the compression level will result in better compression at the 
expense of more CPU and memory. | 0.3.0 | 
celeborn.shuffle.compression.zstd.level | 
 | celeborn.client.shuffle.decompression.lz4.xxhash.instance | 
&lt;undefined&gt; | false | Decompression XXHash instance for Lz4. Available 
options: JNI, JAVASAFE, JAVAUNSAFE. | 0.3.2 |  | 
diff --git 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java
 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java
index 53f442091..1631e90d3 100644
--- 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java
+++ 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java
@@ -114,6 +114,10 @@ public abstract class AbstractMetaManager implements 
IMetadataHandler {
     registeredShuffle.remove(shuffleKey);
   }
 
+  public void updateBatchUnregisterShuffleMeta(List<String> shuffleKeys) {
+    registeredShuffle.removeAll(shuffleKeys);
+  }
+
   public void updateAppHeartbeatMeta(String appId, long time, long 
totalWritten, long fileCount) {
     appHeartbeatTime.put(appId, time);
     partitionTotalWritten.add(totalWritten);
diff --git 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java
 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java
index e9dcb3191..7eb72679e 100644
--- 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java
+++ 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java
@@ -36,6 +36,8 @@ public interface IMetadataHandler {
 
   void handleUnRegisterShuffle(String shuffleKey, String requestId);
 
+  void handleBatchUnRegisterShuffles(List<String> shuffleKeys, String 
requestId);
+
   void handleAppHeartbeat(
       String appId, long totalWritten, long fileCount, long time, String 
requestId);
 
diff --git 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java
 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java
index 2adde50b6..b311b1b45 100644
--- 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java
+++ 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java
@@ -65,6 +65,11 @@ public class SingleMasterMetaManager extends 
AbstractMetaManager {
     updateUnregisterShuffleMeta(shuffleKey);
   }
 
+  @Override
+  public void handleBatchUnRegisterShuffles(List<String> shuffleKeys, String 
requestId) {
+    updateBatchUnregisterShuffleMeta(shuffleKeys);
+  }
+
   @Override
   public void handleAppHeartbeat(
       String appId, long totalWritten, long fileCount, long time, String 
requestId) {
diff --git 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java
 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java
index 738fe6eb6..d6c95afd4 100644
--- 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java
+++ 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java
@@ -109,6 +109,24 @@ public class HAMasterMetaManager extends 
AbstractMetaManager {
     }
   }
 
+  @Override
+  public void handleBatchUnRegisterShuffles(List<String> shuffleKeys, String 
requestId) {
+    try {
+      ratisServer.submitRequest(
+          ResourceRequest.newBuilder()
+              .setCmdType(Type.BatchUnRegisterShuffle)
+              .setRequestId(requestId)
+              .setBatchUnregisterShuffleRequest(
+                  ResourceProtos.BatchUnregisterShuffleRequest.newBuilder()
+                      .addAllShuffleKeys(shuffleKeys)
+                      .build())
+              .build());
+    } catch (CelebornRuntimeException e) {
+      LOG.error("Batch handle unregister shuffle for {} failed!", shuffleKeys, 
e);
+      throw e;
+    }
+  }
+
   @Override
   public void handleAppHeartbeat(
       String appId, long totalWritten, long fileCount, long time, String 
requestId) {
diff --git 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java
 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java
index bb16c1904..a7948dde7 100644
--- 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java
+++ 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java
@@ -118,6 +118,13 @@ public class MetaHandler {
           metaSystem.updateUnregisterShuffleMeta(shuffleKey);
           break;
 
+        case BatchUnRegisterShuffle:
+          List<String> shuffleKeys =
+              request.getBatchUnregisterShuffleRequest().getShuffleKeysList();
+          metaSystem.updateBatchUnregisterShuffleMeta(shuffleKeys);
+          LOG.debug("Handle batch unregister shuffle for {}", shuffleKeys);
+          break;
+
         case AppHeartbeat:
           appId = request.getAppHeartbeatRequest().getAppId();
           LOG.debug("Handle app heartbeat for {}", appId);
diff --git a/master/src/main/proto/Resource.proto 
b/master/src/main/proto/Resource.proto
index 2c129f61e..b83599409 100644
--- a/master/src/main/proto/Resource.proto
+++ b/master/src/main/proto/Resource.proto
@@ -40,6 +40,7 @@ enum Type {
   WorkerEvent = 25;
   ApplicationMeta = 26;
   ReportWorkerDecommission = 27;
+  BatchUnRegisterShuffle = 28;
 }
 
 enum WorkerEventType {
@@ -73,6 +74,7 @@ message ResourceRequest {
   optional WorkerEventRequest workerEventRequest = 22;
   optional ApplicationMetaRequest applicationMetaRequest = 23;
   optional ReportWorkerDecommissionRequest reportWorkerDecommissionRequest = 
24;
+  optional BatchUnregisterShuffleRequest batchUnregisterShuffleRequest = 25;
 }
 
 message DiskInfo {
@@ -106,6 +108,10 @@ message UnregisterShuffleRequest {
   required string shuffleKey = 1;
 }
 
+message BatchUnregisterShuffleRequest {
+  repeated string shuffleKeys = 1;
+}
+
 message AppHeartbeatRequest {
   required string appId = 1;
   required int64 time = 2;
diff --git 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index b9e4887eb..544c9656b 100644
--- 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++ 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -460,6 +460,16 @@ private[celeborn] class Master(
       checkAuth(context, applicationId)
       executeWithLeaderChecker(context, handleRequestSlots(context, 
requestSlots))
 
+    case pb: PbBatchUnregisterShuffles =>
+      val applicationId = pb.getAppId
+      val shuffleIds = pb.getShuffleIdsList.asScala.toList
+      val requestId = pb.getRequestId
+      logDebug(s"Received BatchUnregisterShuffle request $requestId, 
$applicationId, $shuffleIds")
+      checkAuth(context, applicationId)
+      executeWithLeaderChecker(
+        context,
+        batchHandleUnregisterShuffles(context, applicationId, shuffleIds, 
requestId))
+
     case pb: PbUnregisterShuffle =>
       val applicationId = pb.getAppId
       val shuffleId = pb.getShuffleId
@@ -982,6 +992,18 @@ private[celeborn] class Master(
     context.reply(UnregisterShuffleResponse(StatusCode.SUCCESS))
   }
 
+  def batchHandleUnregisterShuffles(
+      context: RpcCallContext,
+      applicationId: String,
+      shuffleIds: List[Integer],
+      requestId: String): Unit = {
+    val shuffleKeys =
+      shuffleIds.map(shuffleId => Utils.makeShuffleKey(applicationId, 
shuffleId)).asJava
+    statusSystem.handleBatchUnRegisterShuffles(shuffleKeys, requestId)
+    logInfo(s"BatchUnregister shuffle $shuffleKeys")
+    context.reply(BatchUnregisterShuffleResponse(StatusCode.SUCCESS, 
shuffleIds.asJava))
+  }
+
   private def handleReportNodeUnavailable(
       context: RpcCallContext,
       failedWorkers: util.List[WorkerInfo],
diff --git 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java
 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java
index 5964ba700..b571e49f6 100644
--- 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java
+++ 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java
@@ -552,6 +552,88 @@ public class DefaultMetaSystemSuiteJ {
     assertTrue(statusSystem.registeredShuffle.isEmpty());
   }
 
+  @Test
+  public void testHandleBatchUnRegisterShuffle() {
+    statusSystem.handleRegisterWorker(
+        HOSTNAME1,
+        RPCPORT1,
+        PUSHPORT1,
+        FETCHPORT1,
+        REPLICATEPORT1,
+        INTERNALPORT1,
+        NETWORK_LOCATION1,
+        disks1,
+        userResourceConsumption1,
+        getNewReqeustId());
+    statusSystem.handleRegisterWorker(
+        HOSTNAME2,
+        RPCPORT2,
+        PUSHPORT2,
+        FETCHPORT2,
+        REPLICATEPORT2,
+        INTERNALPORT2,
+        NETWORK_LOCATION2,
+        disks2,
+        userResourceConsumption2,
+        getNewReqeustId());
+    statusSystem.handleRegisterWorker(
+        HOSTNAME3,
+        RPCPORT3,
+        PUSHPORT3,
+        FETCHPORT3,
+        REPLICATEPORT3,
+        INTERNALPORT3,
+        NETWORK_LOCATION3,
+        disks3,
+        userResourceConsumption3,
+        getNewReqeustId());
+
+    WorkerInfo workerInfo1 =
+        new WorkerInfo(
+            HOSTNAME1,
+            RPCPORT1,
+            PUSHPORT1,
+            FETCHPORT1,
+            REPLICATEPORT1,
+            INTERNALPORT1,
+            disks1,
+            userResourceConsumption1);
+    WorkerInfo workerInfo2 =
+        new WorkerInfo(
+            HOSTNAME2,
+            RPCPORT2,
+            PUSHPORT2,
+            FETCHPORT2,
+            REPLICATEPORT2,
+            INTERNALPORT2,
+            disks2,
+            userResourceConsumption2);
+
+    Map<String, Map<String, Integer>> workersToAllocate = new HashMap<>();
+    Map<String, Integer> allocation = new HashMap<>();
+    allocation.put("disk1", 5);
+    workersToAllocate.put(workerInfo1.toUniqueId(), allocation);
+    workersToAllocate.put(workerInfo2.toUniqueId(), allocation);
+
+    List<String> shuffleKeysAll = new ArrayList<>();
+    for (int i = 1; i <= 4; i++) {
+      String shuffleKey = APPID1 + "-" + i;
+      shuffleKeysAll.add(shuffleKey);
+      statusSystem.handleRequestSlots(shuffleKey, HOSTNAME1, 
workersToAllocate, getNewReqeustId());
+    }
+    Assert.assertEquals(4, statusSystem.registeredShuffle.size());
+
+    List<String> shuffleKeys1 = new ArrayList<>();
+    shuffleKeys1.add(shuffleKeysAll.get(0));
+
+    statusSystem.handleBatchUnRegisterShuffles(shuffleKeys1, 
getNewReqeustId());
+    Assert.assertEquals(3, statusSystem.registeredShuffle.size());
+
+    statusSystem.handleBatchUnRegisterShuffles(shuffleKeysAll, 
getNewReqeustId());
+
+    Assert.assertTrue(statusSystem.registeredShuffle.isEmpty());
+  }
+
   @Test
   public void testHandleAppHeartbeat() {
     Long dummy = 1235L;
diff --git 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
index 72949f5d1..851d87c1c 100644
--- 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
+++ 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
@@ -844,6 +844,102 @@ public class RatisMasterStatusSystemSuiteJ {
     Assert.assertTrue(STATUSSYSTEM3.registeredShuffle.isEmpty());
   }
 
+  @Test
+  public void testHandleBatchUnRegisterShuffle() throws InterruptedException {
+    AbstractMetaManager statusSystem = pickLeaderStatusSystem();
+    Assert.assertNotNull(statusSystem);
+
+    statusSystem.handleRegisterWorker(
+        HOSTNAME1,
+        RPCPORT1,
+        PUSHPORT1,
+        FETCHPORT1,
+        REPLICATEPORT1,
+        INTERNALPORT1,
+        NETWORK_LOCATION1,
+        disks1,
+        userResourceConsumption1,
+        getNewReqeustId());
+    statusSystem.handleRegisterWorker(
+        HOSTNAME2,
+        RPCPORT2,
+        PUSHPORT2,
+        FETCHPORT2,
+        REPLICATEPORT2,
+        INTERNALPORT2,
+        NETWORK_LOCATION2,
+        disks2,
+        userResourceConsumption2,
+        getNewReqeustId());
+    statusSystem.handleRegisterWorker(
+        HOSTNAME3,
+        RPCPORT3,
+        PUSHPORT3,
+        FETCHPORT3,
+        REPLICATEPORT3,
+        INTERNALPORT3,
+        NETWORK_LOCATION3,
+        disks3,
+        userResourceConsumption3,
+        getNewReqeustId());
+
+    WorkerInfo workerInfo1 =
+        new WorkerInfo(
+            HOSTNAME1,
+            RPCPORT1,
+            PUSHPORT1,
+            FETCHPORT1,
+            REPLICATEPORT1,
+            INTERNALPORT1,
+            disks1,
+            userResourceConsumption1);
+    WorkerInfo workerInfo2 =
+        new WorkerInfo(
+            HOSTNAME2,
+            RPCPORT2,
+            PUSHPORT2,
+            FETCHPORT2,
+            REPLICATEPORT2,
+            INTERNALPORT2,
+            disks2,
+            userResourceConsumption2);
+
+    Map<String, Map<String, Integer>> workersToAllocate = new HashMap<>();
+    Map<String, Integer> allocations = new HashMap<>();
+    allocations.put("disk1", 5);
+    workersToAllocate.put(workerInfo1.toUniqueId(), allocations);
+    workersToAllocate.put(workerInfo2.toUniqueId(), allocations);
+
+    List<String> shuffleKeysAll = new ArrayList<>();
+    for (int i = 1; i <= 4; i++) {
+      String shuffleKey = APPID1 + "-" + i;
+      shuffleKeysAll.add(shuffleKey);
+      statusSystem.handleRequestSlots(shuffleKey, HOSTNAME1, 
workersToAllocate, getNewReqeustId());
+    }
+
+    Thread.sleep(3000L);
+
+    Assert.assertEquals(4, STATUSSYSTEM1.registeredShuffle.size());
+    Assert.assertEquals(4, STATUSSYSTEM2.registeredShuffle.size());
+    Assert.assertEquals(4, STATUSSYSTEM3.registeredShuffle.size());
+
+    List<String> shuffleKeys1 = new ArrayList<>();
+    shuffleKeys1.add(shuffleKeysAll.get(0));
+
+    statusSystem.handleBatchUnRegisterShuffles(shuffleKeys1, 
getNewReqeustId());
+    Thread.sleep(3000L);
+    Assert.assertEquals(3, STATUSSYSTEM1.registeredShuffle.size());
+    Assert.assertEquals(3, STATUSSYSTEM2.registeredShuffle.size());
+    Assert.assertEquals(3, STATUSSYSTEM3.registeredShuffle.size());
+
+    statusSystem.handleBatchUnRegisterShuffles(shuffleKeysAll, 
getNewReqeustId());
+    Thread.sleep(3000L);
+
+    Assert.assertTrue(STATUSSYSTEM1.registeredShuffle.isEmpty());
+    Assert.assertTrue(STATUSSYSTEM2.registeredShuffle.isEmpty());
+    Assert.assertTrue(STATUSSYSTEM3.registeredShuffle.isEmpty());
+  }
+
   @Test
   public void testHandleAppHeartbeat() throws InterruptedException {
     AbstractMetaManager statusSystem = pickLeaderStatusSystem();
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala
new file mode 100644
index 000000000..3b7d74df8
--- /dev/null
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.client
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.{interval, timeout}
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
+
+import org.apache.celeborn.client.{LifecycleManager, WithShuffleClientSuite}
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.Utils
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+class LifecycleManagerUnregisterShuffleSuite extends WithShuffleClientSuite
+  with MiniClusterFeature {
+
+  celebornConf
+    .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true")
+    .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val (master, _) = setupMiniClusterWithRandomPorts()
+    celebornConf.set(
+      CelebornConf.MASTER_ENDPOINTS.key,
+      master.conf.get(CelebornConf.MASTER_ENDPOINTS.key))
+  }
+
+  test("test unregister shuffle in batch") {
+    val conf = celebornConf.clone
+    conf.set(CelebornConf.CLIENT_BATCH_REMOVE_EXPIRED_SHUFFLE_ENABLED.key, 
"true")
+    val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+    val counts = 10
+    val ids =
+      new util.ArrayList[Integer]((0 until counts).toList.map(x => 
Integer.valueOf(x)).asJava)
+    val shuffleIds = (1 to counts).toList
+
+    shuffleIds.foreach { shuffleId: Int =>
+      val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 
ids)
+      assert(res.status == StatusCode.SUCCESS)
+      lifecycleManager.registeredShuffle.add(shuffleId)
+      assert(lifecycleManager.registeredShuffle.contains(shuffleId))
+      val shuffleKey = Utils.makeShuffleKey(APP, shuffleId)
+      assert(masterInfo._1.statusSystem.registeredShuffle.contains(shuffleKey))
+      lifecycleManager.commitManager.setStageEnd(shuffleId)
+    }
+
+    shuffleIds.foreach { shuffleId: Int =>
+      lifecycleManager.unregisterShuffle(shuffleId)
+    }
+    // after unregister shuffle
+    eventually(timeout(120.seconds), interval(2.seconds)) {
+      shuffleIds.foreach { shuffleId: Int =>
+        val shuffleKey = Utils.makeShuffleKey(APP, shuffleId)
+        assert(!lifecycleManager.registeredShuffle.contains(shuffleId))
+        
assert(!masterInfo._1.statusSystem.registeredShuffle.contains(shuffleKey))
+      }
+    }
+
+    lifecycleManager.stop()
+  }
+
+  test("test unregister shuffle") {
+    val conf = celebornConf.clone
+    val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+    val counts = 10
+    val ids =
+      new util.ArrayList[Integer]((0 until counts).toList.map(x => 
Integer.valueOf(x)).asJava)
+    val shuffleIds = (1 to counts).toList
+
+    shuffleIds.foreach { shuffleId: Int =>
+      val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 
ids)
+      assert(res.status == StatusCode.SUCCESS)
+      lifecycleManager.registeredShuffle.add(shuffleId)
+      assert(lifecycleManager.registeredShuffle.contains(shuffleId))
+      val shuffleKey = Utils.makeShuffleKey(APP, shuffleId)
+      assert(masterInfo._1.statusSystem.registeredShuffle.contains(shuffleKey))
+      lifecycleManager.commitManager.setStageEnd(shuffleId)
+    }
+    val previousTime = System.currentTimeMillis()
+    shuffleIds.foreach { shuffleId: Int =>
+      lifecycleManager.unregisterShuffle(shuffleId)
+    }
+    // after unregister shuffle
+    eventually(timeout(120.seconds), interval(2.seconds)) {
+      shuffleIds.foreach { shuffleId: Int =>
+        val shuffleKey = Utils.makeShuffleKey(APP, shuffleId)
+        assert(!lifecycleManager.registeredShuffle.contains(shuffleId))
+        
assert(!masterInfo._1.statusSystem.registeredShuffle.contains(shuffleKey))
+      }
+    }
+    val currentTime = System.currentTimeMillis()
+    assert(currentTime - previousTime > conf.shuffleExpiredCheckIntervalMs)
+    lifecycleManager.stop()
+  }
+
+  override def afterAll(): Unit = {
+    logInfo("all test complete , stop celeborn mini cluster")
+    shutdownMiniCluster()
+  }
+}

Reply via email to