This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 7685fa7db [CELEBORN-1636] Client supports dynamic update of Worker 
resources on the server
7685fa7db is described below

commit 7685fa7db22a156d42f8824192ccd6264d351de7
Author: szt <[email protected]>
AuthorDate: Mon Oct 28 09:49:31 2024 +0800

    [CELEBORN-1636] Client supports dynamic update of Worker resources on the 
server
    
    ### What changes were proposed in this pull request?
    Currently, the ChangePartitionManager retrieves candidate workers from the 
LifecycleManager's workerSnapshots. However, during the revive process in 
reallocateChangePartitionRequestSlotsFromCandidates, it does not account for 
changes in the set of available workers caused by elastic scaling, such as 
workers newly added to the cluster. This PR addresses the issue as follows. 
When celeborn.client.shuffle.dynamicResourceEnabled is set, the 
ChangePartitionManager takes its candidate workers from the available workers 
that the master reports in its heartbeat response.
    
    ### Does this PR introduce _any_ user-facing change?
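    No, apart from the new client configuration 
celeborn.client.shuffle.dynamicResourceEnabled (documented in 
docs/configuration/client.md).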
    
    ### How was this patch tested?
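    Unit tests: WorkerStatusTrackerSuite and the new 
ChangePartitionManagerUpdateWorkersSuite, plus an updated RetryReviveTest.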
    
    Closes #2835 from zaynt4606/clbdev.
    
    Authored-by: szt <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../celeborn/client/ApplicationHeartbeater.scala   |   3 +
 .../celeborn/client/ChangePartitionManager.scala   | 126 +++++++----
 .../apache/celeborn/client/LifecycleManager.scala  |   8 +-
 .../celeborn/client/WorkerStatusTracker.scala      |  61 +++++-
 .../celeborn/client/WorkerStatusTrackerSuite.scala | 121 ++++++++++-
 common/src/main/proto/TransportMessages.proto      |   2 +
 .../org/apache/celeborn/common/CelebornConf.scala  |  11 +
 .../common/protocol/message/ControlMessages.scala  |  10 +
 docs/configuration/client.md                       |   1 +
 .../celeborn/service/deploy/master/Master.scala    |  10 +
 .../ChangePartitionManagerUpdateWorkersSuite.scala | 232 +++++++++++++++++++++
 .../client/LifecycleManagerCommitFilesSuite.scala  |  10 +-
 .../client/LifecycleManagerDestroySlotsSuite.scala |  15 +-
 .../LifecycleManagerSetupEndpointSuite.scala       |   4 +-
 .../celeborn/tests/spark/RetryReviveTest.scala     |  30 ++-
 .../service/deploy/MiniClusterFeature.scala        |  22 +-
 16 files changed, 589 insertions(+), 77 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala 
b/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
index 0b850ff52..4b90a8bca 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
@@ -47,6 +47,7 @@ class ApplicationHeartbeater(
   // Use independent app heartbeat threads to avoid being blocked by other 
operations.
   private val appHeartbeatIntervalMs = conf.appHeartbeatIntervalMs
   private val applicationUnregisterEnabled = conf.applicationUnregisterEnabled
+  private val clientShuffleDynamicResourceEnabled = 
conf.clientShuffleDynamicResourceEnabled
   private val appHeartbeatHandlerThread =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor(
       "celeborn-client-lifecycle-manager-app-heartbeater")
@@ -69,6 +70,7 @@ class ApplicationHeartbeater(
                 tmpTotalWritten,
                 tmpTotalFileCount,
                 workerStatusTracker.getNeedCheckedWorkers().toList.asJava,
+                clientShuffleDynamicResourceEnabled,
                 ZERO_UUID,
                 true)
             val response = requestHeartbeat(appHeartbeat)
@@ -129,6 +131,7 @@ class ApplicationHeartbeater(
           List.empty.asJava,
           List.empty.asJava,
           List.empty.asJava,
+          List.empty.asJava,
           List.empty.asJava)
     }
   }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index 46628aaab..771b51151 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -18,14 +18,15 @@
 package org.apache.celeborn.client
 
 import java.util
-import java.util.{Set => JSet}
+import java.util.{function, Set => JSet}
 import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, 
ScheduledFuture, TimeUnit}
 
 import scala.collection.JavaConverters._
 
+import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.WorkerInfo
+import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, 
WorkerInfo}
 import org.apache.celeborn.common.protocol.PartitionLocation
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
 import org.apache.celeborn.common.protocol.message.StatusCode
@@ -45,7 +46,8 @@ class ChangePartitionManager(
 
   private val pushReplicateEnabled = conf.clientPushReplicateEnabled
   // shuffleId -> (partitionId -> set of ChangePartition)
-  private val changePartitionRequests =
+  val changePartitionRequests
+      : ConcurrentHashMap[Int, ConcurrentHashMap[Integer, 
JSet[ChangePartitionRequest]]] =
     JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, 
JSet[ChangePartitionRequest]]]()
 
   // shuffleId -> locks
@@ -74,6 +76,8 @@ class ChangePartitionManager(
 
   private val testRetryRevive = conf.testRetryRevive
 
+  private val clientShuffleDynamicResourceEnabled = 
conf.clientShuffleDynamicResourceEnabled
+
   def start(): Unit = {
     batchHandleChangePartition = batchHandleChangePartitionSchedulerThread.map 
{
       // noinspection ConvertExpressionToSAM
@@ -128,7 +132,8 @@ class ChangePartitionManager(
     batchHandleChangePartitionSchedulerThread.foreach(ThreadUtils.shutdown(_))
   }
 
-  private val rpcContextRegisterFunc =
+  val rpcContextRegisterFunc
+      : function.Function[Int, ConcurrentHashMap[Integer, 
JSet[ChangePartitionRequest]]] =
     new util.function.Function[
       Int,
       ConcurrentHashMap[Integer, util.Set[ChangePartitionRequest]]]() {
@@ -148,6 +153,13 @@ class ChangePartitionManager(
     }
   }
 
+  private val updateWorkerSnapshotsFunc =
+    new util.function.Function[WorkerInfo, ShufflePartitionLocationInfo] {
+      override def apply(w: WorkerInfo): ShufflePartitionLocationInfo = {
+        new ShufflePartitionLocationInfo()
+      }
+    }
+
   def handleRequestPartitionLocation(
       context: RequestLocationCallContext,
       shuffleId: Int,
@@ -186,7 +198,7 @@ class ChangePartitionManager(
             partitionId,
             StatusCode.SUCCESS,
             Some(latestLoc),
-            lifecycleManager.workerStatusTracker.workerAvailable(oldPartition))
+            
lifecycleManager.workerStatusTracker.workerAvailableByLocation(oldPartition))
           logDebug(s"[handleRequestPartitionLocation]: For shuffle: 
$shuffleId," +
             s" old partition: $partitionId-$oldEpoch, new partition: 
$latestLoc found, return it")
           return
@@ -254,7 +266,7 @@ class ChangePartitionManager(
             req.partitionId,
             StatusCode.SUCCESS,
             Option(newLocation),
-            
lifecycleManager.workerStatusTracker.workerAvailable(req.oldPartition))))
+            
lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition))))
       }
     }
 
@@ -274,18 +286,49 @@ class ChangePartitionManager(
             req.partitionId,
             status,
             None,
-            
lifecycleManager.workerStatusTracker.workerAvailable(req.oldPartition))))
+            
lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition))))
       }
     }
 
-    // Get candidate worker that not in excluded worker list of shuffleId
-    val candidates =
-      lifecycleManager
-        .workerSnapshots(shuffleId)
-        .keySet()
+    val candidates = new util.HashSet[WorkerInfo]()
+    if (clientShuffleDynamicResourceEnabled) {
+      // availableWorkers wont filter excludedWorkers in heartBeat So have to 
do filtering.
+      candidates.addAll(lifecycleManager
+        .workerStatusTracker
+        .availableWorkersWithEndpoint
+        .values()
         .asScala
+        .toSet
         .filter(lifecycleManager.workerStatusTracker.workerAvailable)
-        .toList
+        .asJava)
+
+      // SetupEndpoint for those availableWorkers without endpoint
+      val workersRequireEndpoints = new util.HashSet[WorkerInfo](
+        
lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.asScala.filter(
+          lifecycleManager.workerStatusTracker.workerAvailable).asJava)
+      val connectFailedWorkers = new ShuffleFailedWorkers()
+      lifecycleManager.setupEndpoints(
+        workersRequireEndpoints,
+        shuffleId,
+        connectFailedWorkers)
+      
workersRequireEndpoints.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
+      candidates.addAll(workersRequireEndpoints)
+
+      // Update worker status
+      
lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(workersRequireEndpoints)
+      
lifecycleManager.workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
+      
lifecycleManager.workerStatusTracker.removeFromExcludedWorkers(candidates)
+    } else {
+      val snapshotCandidates =
+        lifecycleManager
+          .workerSnapshots(shuffleId)
+          .keySet()
+          .asScala
+          .filter(lifecycleManager.workerStatusTracker.workerAvailable)
+          .asJava
+      candidates.addAll(snapshotCandidates)
+    }
+
     if (candidates.size < 1 || (pushReplicateEnabled && candidates.size < 2)) {
       logError("[Update partition] failed for not enough candidates for 
revive.")
       replyFailure(StatusCode.SLOT_NOT_AVAILABLE)
@@ -294,11 +337,13 @@ class ChangePartitionManager(
 
     // PartitionSplit all contains oldPartition
     val newlyAllocatedLocations =
-      
reallocateChangePartitionRequestSlotsFromCandidates(changePartitions.toList, 
candidates)
+      reallocateChangePartitionRequestSlotsFromCandidates(
+        changePartitions.toList,
+        candidates.asScala.toList)
 
     if (!lifecycleManager.reserveSlotsWithRetry(
         shuffleId,
-        new util.HashSet(candidates.toSet.asJava),
+        candidates,
         newlyAllocatedLocations,
         isSegmentGranularityVisible = isSegmentGranularityVisible)) {
       logError(s"[Update partition] failed for $shuffleId.")
@@ -306,33 +351,32 @@ class ChangePartitionManager(
       return
     }
 
-    val newPrimaryLocations =
-      newlyAllocatedLocations.asScala.flatMap {
-        case (workInfo, (primaryLocations, replicaLocations)) =>
-          // Add all re-allocated slots to worker snapshots.
-          lifecycleManager.workerSnapshots(shuffleId).asScala
-            .get(workInfo)
-            .foreach { partitionLocationInfo =>
-              partitionLocationInfo.addPrimaryPartitions(primaryLocations)
-              lifecycleManager.updateLatestPartitionLocations(shuffleId, 
primaryLocations)
-              partitionLocationInfo.addReplicaPartitions(replicaLocations)
-            }
-          // partition location can be null when call reserveSlotsWithRetry().
-          val locations = (primaryLocations.asScala ++ 
replicaLocations.asScala.map(_.getPeer))
-            .distinct.filter(_ != null)
-          if (locations.nonEmpty) {
-            val changes = locations.map { partition =>
-              s"(partition ${partition.getId} epoch from ${partition.getEpoch 
- 1} to ${partition.getEpoch})"
-            }.mkString("[", ", ", "]")
-            logInfo(s"[Update partition] success for " +
-              s"shuffle $shuffleId, succeed partitions: " +
-              s"$changes.")
-          }
+    val newPrimaryLocations = newlyAllocatedLocations.asScala.flatMap {
+      case (workInfo, (primaryLocations, replicaLocations)) =>
+        // Add all re-allocated slots to worker snapshots.
+        val partitionLocationInfo = 
lifecycleManager.workerSnapshots(shuffleId).computeIfAbsent(
+          workInfo,
+          updateWorkerSnapshotsFunc)
+        partitionLocationInfo.addPrimaryPartitions(primaryLocations)
+        partitionLocationInfo.addReplicaPartitions(replicaLocations)
+        lifecycleManager.updateLatestPartitionLocations(shuffleId, 
primaryLocations)
+
+        // partition location can be null when call reserveSlotsWithRetry().
+        val locations = (primaryLocations.asScala ++ 
replicaLocations.asScala.map(_.getPeer))
+          .distinct.filter(_ != null)
+        if (locations.nonEmpty) {
+          val changes = locations.map { partition =>
+            s"(partition ${partition.getId} epoch from ${partition.getEpoch - 
1} to ${partition.getEpoch})"
+          }.mkString("[", ", ", "]")
+          logInfo(s"[Update partition] success for " +
+            s"shuffle $shuffleId, succeed partitions: " +
+            s"$changes.")
+        }
 
-          // TODO: should record the new partition locations and acknowledge 
the new partitionLocations to downstream task,
-          //  in scenario the downstream task start early before the upstream 
task.
-          locations
-      }
+        // TODO: should record the new partition locations and acknowledge the 
new partitionLocations to downstream task,
+        //  in scenario the downstream task start early before the upstream 
task.
+        locations
+    }
     replySuccess(newPrimaryLocations.toArray)
   }
 
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index d35802958..38b9f28ba 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -449,11 +449,11 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
   }
 
   def setupEndpoints(
-      slots: WorkerResource,
+      workers: util.Set[WorkerInfo],
       shuffleId: Int,
       connectFailedWorkers: ShuffleFailedWorkers): Unit = {
     val futures = new util.LinkedList[(Future[RpcEndpointRef], WorkerInfo)]()
-    slots.asScala foreach { case (workerInfo, _) =>
+    workers.asScala foreach { workerInfo =>
       val future = 
workerRpcEnvInUse.asyncSetupEndpointRefByAddr(RpcEndpointAddress(
         RpcAddress.apply(workerInfo.host, workerInfo.rpcPort),
         WORKER_EP))
@@ -676,8 +676,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
     val connectFailedWorkers = new ShuffleFailedWorkers()
 
     // Second, for each worker, try to initialize the endpoint.
-    setupEndpoints(slots, shuffleId, connectFailedWorkers)
-
+    setupEndpoints(slots.keySet(), shuffleId, connectFailedWorkers)
     
candidatesWorkers.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
     workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
     // If newly allocated from primary and can setup endpoint success, 
LifecycleManager should remove worker from
@@ -713,6 +712,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         allocatedWorkers.put(workerInfo, partitionLocationInfo)
       }
       shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
+      workerStatusTracker.addWorkersWithEndpoint(candidatesWorkers)
       registeredShuffle.add(shuffleId)
       commitManager.registerShuffle(
         shuffleId,
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala 
b/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
index d7ccf0fe8..088341813 100644
--- a/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
@@ -30,17 +30,27 @@ import org.apache.celeborn.common.meta.WorkerInfo
 import org.apache.celeborn.common.protocol.PartitionLocation
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.HeartbeatFromApplicationResponse
 import org.apache.celeborn.common.protocol.message.StatusCode
-import org.apache.celeborn.common.util.Utils
+import org.apache.celeborn.common.util.{JavaUtils, Utils}
 
 class WorkerStatusTracker(
     conf: CelebornConf,
     lifecycleManager: LifecycleManager) extends Logging {
   private val excludedWorkerExpireTimeout = 
conf.clientExcludedWorkerExpireTimeout
   private val workerStatusListeners = 
ConcurrentHashMap.newKeySet[WorkerStatusListener]()
+  private val clientShuffleDynamicResourceEnabled = 
conf.clientShuffleDynamicResourceEnabled
 
   val excludedWorkers = new ShuffleFailedWorkers()
   val shuttingWorkers: JSet[WorkerInfo] = new JHashSet[WorkerInfo]()
 
+  // Workers that have already set an endpoint can skip the setupEndpoint 
process in changePartition when reviving
+  // key: WorkerInfo.toUniqueId value: WorkerInfo
+  val availableWorkersWithEndpoint: ConcurrentHashMap[String, WorkerInfo] =
+    JavaUtils.newConcurrentHashMap[String, WorkerInfo]()
+
+  // Workers that may be available but have not been used(without endpoint)
+  // availableWorkersWithoutEndpoint is empty until 
appHeartbeatWithAvailableWorkers set to true
+  val availableWorkersWithoutEndpoint = 
ConcurrentHashMap.newKeySet[WorkerInfo]()
+
   def registerWorkerStatusListener(workerStatusListener: 
WorkerStatusListener): Unit = {
     workerStatusListeners.add(workerStatusListener)
   }
@@ -61,7 +71,7 @@ class WorkerStatusTracker(
     !excludedWorkers.containsKey(worker) && !shuttingWorkers.contains(worker)
   }
 
-  def workerAvailable(loc: PartitionLocation): Boolean = {
+  def workerAvailableByLocation(loc: PartitionLocation): Boolean = {
     if (loc == null) {
       false
     } else {
@@ -131,13 +141,16 @@ class WorkerStatusTracker(
       failedWorkers.asScala.foreach {
         case (worker, (StatusCode.WORKER_SHUTDOWN, _)) =>
           shuttingWorkers.add(worker)
+          removeFromAvailableWorkers(worker)
         case (worker, (statusCode, registerTime)) if 
!excludedWorkers.containsKey(worker) =>
           excludedWorkers.put(worker, (statusCode, registerTime))
+          removeFromAvailableWorkers(worker)
         case (worker, (statusCode, _))
             if statusCode == StatusCode.NO_AVAILABLE_WORKING_DIR ||
               statusCode == StatusCode.RESERVE_SLOTS_FAILED ||
               statusCode == StatusCode.WORKER_UNKNOWN =>
           excludedWorkers.put(worker, (statusCode, 
excludedWorkers.get(worker)._2))
+          removeFromAvailableWorkers(worker)
         case _ => // Not cover
       }
     }
@@ -147,10 +160,22 @@ class WorkerStatusTracker(
     excludedWorkers.keySet.removeAll(workers)
   }
 
+  private def removeFromAvailableWorkers(worker: WorkerInfo): Unit = {
+    availableWorkersWithEndpoint.remove(worker.toUniqueId())
+    availableWorkersWithoutEndpoint.remove(worker)
+  }
+
+  def addWorkersWithEndpoint(workers: JHashSet[WorkerInfo]): Unit = {
+    availableWorkersWithoutEndpoint.removeAll(workers)
+    workers.asScala.foreach { workerInfo =>
+      availableWorkersWithEndpoint.put(workerInfo.toUniqueId(), workerInfo)
+    }
+  }
+
   def handleHeartbeatResponse(res: HeartbeatFromApplicationResponse): Unit = {
     if (res.statusCode == StatusCode.SUCCESS) {
       logDebug(s"Received Worker status from Primary, excluded workers: 
${res.excludedWorkers} " +
-        s"unknown workers: ${res.unknownWorkers}, shutdown workers: 
${res.shuttingWorkers}")
+        s"unknown workers: ${res.unknownWorkers}, shutdown workers: 
${res.shuttingWorkers}, available workers from heartbeat: 
${res.availableWorkers}")
       val current = System.currentTimeMillis()
       var statusChanged = false
 
@@ -188,9 +213,33 @@ class WorkerStatusTracker(
           statusChanged = true
         }
       }
-      val retainResult = shuttingWorkers.retainAll(res.shuttingWorkers)
-      val addResult = shuttingWorkers.addAll(res.shuttingWorkers)
-      statusChanged = statusChanged || retainResult || addResult
+
+      val retainShuttingWorkersResult = 
shuttingWorkers.retainAll(res.shuttingWorkers)
+      val addShuttingWorkersResult = 
shuttingWorkers.addAll(res.shuttingWorkers)
+
+      if (clientShuffleDynamicResourceEnabled) {
+        // AvailableWorkers filter Client excludedWorkers and shuttingWorkers.
+        // AvailableWorkers already filtered res.excludedWorkers and 
res.shuttingWorkers.
+        val resAvailableWorkers: JSet[WorkerInfo] = new 
JHashSet[WorkerInfo](res.availableWorkers)
+        // update availableWorkers
+        // availableWorkers wont filter excludedWorkers.
+        // So before using them we hava to filter excludedWorkers.
+        availableWorkersWithoutEndpoint.retainAll(resAvailableWorkers)
+        availableWorkersWithEndpoint.keySet().retainAll(
+          resAvailableWorkers.asScala.map(_.toUniqueId()).asJava)
+        resAvailableWorkers.asScala.foreach { workerInfo: WorkerInfo =>
+          if 
(!availableWorkersWithEndpoint.keySet.contains(workerInfo.toUniqueId())) {
+            availableWorkersWithoutEndpoint.add(workerInfo)
+          } else {
+            if (availableWorkersWithoutEndpoint.contains(workerInfo)) {
+              availableWorkersWithoutEndpoint.remove(workerInfo)
+            }
+          }
+        }
+      }
+
+      statusChanged =
+        statusChanged || retainShuttingWorkersResult || 
addShuttingWorkersResult
       // Always trigger commit files for shutting down workers from 
HeartbeatFromApplicationResponse
       // See details in CELEBORN-696
       if (!res.unknownWorkers.isEmpty || !res.shuttingWorkers.isEmpty) {
diff --git 
a/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
 
b/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
index d2c21245d..27196e8d9 100644
--- 
a/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
+++ 
b/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
@@ -23,16 +23,16 @@ import org.junit.Assert
 
 import org.apache.celeborn.CelebornFunSuite
 import org.apache.celeborn.common.CelebornConf
-import 
org.apache.celeborn.common.CelebornConf.CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT
+import 
org.apache.celeborn.common.CelebornConf.{CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT, 
CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED}
 import org.apache.celeborn.common.meta.WorkerInfo
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.HeartbeatFromApplicationResponse
 import org.apache.celeborn.common.protocol.message.StatusCode
 
 class WorkerStatusTrackerSuite extends CelebornFunSuite {
-
-  test("handleHeartbeatResponse") {
+  test("handleHeartbeatResponse without availableWorkers") {
     val celebornConf = new CelebornConf()
     celebornConf.set(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT, 2000L)
+    celebornConf.set(CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED, false)
     val statusTracker = new WorkerStatusTracker(celebornConf, null)
 
     val registerTime = System.currentTimeMillis()
@@ -40,7 +40,7 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
     statusTracker.excludedWorkers.put(mock("host2"), 
(StatusCode.WORKER_SHUTDOWN, registerTime))
 
     // test reserve (only statusCode list in handleHeartbeatResponse)
-    val empty = buildResponse(Array.empty, Array.empty, Array.empty)
+    val empty = buildResponse(Array.empty, Array.empty, Array.empty, 
Array.empty)
     statusTracker.handleHeartbeatResponse(empty)
 
     // only reserve host1
@@ -50,7 +50,8 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
     
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host2")))
 
     // add shutdown/excluded worker
-    val response1 = buildResponse(Array("host0"), Array("host1", "host3"), 
Array("host4"))
+    val response1 =
+      buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"), 
Array.empty)
     statusTracker.handleHeartbeatResponse(response1)
 
     // test keep Unknown register time
@@ -58,15 +59,15 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
       statusTracker.excludedWorkers.get(mock("host1")),
       (StatusCode.WORKER_UNKNOWN, registerTime))
 
-    // test new added workers
+    // test new added shutdown/excluded workers
     Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host0")))
     Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host3")))
     
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
     Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
 
     // test re heartbeat with shutdown workers
-    val response3 = buildResponse(Array.empty, Array.empty, Array("host4"))
-    statusTracker.handleHeartbeatResponse(response3)
+    val response2 = buildResponse(Array.empty, Array.empty, Array("host4"), 
Array.empty)
+    statusTracker.handleHeartbeatResponse(response2)
     
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
     Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
 
@@ -78,24 +79,124 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
 
     // test register time elapsed
     Thread.sleep(3000)
-    val response2 = buildResponse(Array.empty, Array("host5", "host6"), 
Array.empty)
+    val response3 = buildResponse(Array.empty, Array("host5", "host6"), 
Array.empty, Array.empty)
+    statusTracker.handleHeartbeatResponse(response3)
+    Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
+    
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host1")))
+
+    // test available workers
+    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
0)
+    val response4 = buildResponse(
+      Array.empty,
+      Array.empty,
+      Array.empty,
+      Array("host5", "host6", "host7", "host8"))
+    statusTracker.handleHeartbeatResponse(response4)
+
+    // availableWorkers wont update through heartbeat
+    // when DYNAMIC_RESOURCE_ENABLE set to false
+    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
0)
+    // available workers won't overwrite excluded workers
+    Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
+    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host5")))
+    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host6")))
+  }
+
+  test("handleHeartbeatResponse with availableWorkers") {
+    val celebornConf = new CelebornConf()
+    celebornConf.set(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT, 2000L)
+    celebornConf.set(CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED, true)
+    val statusTracker = new WorkerStatusTracker(celebornConf, null)
+
+    val registerTime = System.currentTimeMillis()
+    statusTracker.excludedWorkers.put(mock("host1"), 
(StatusCode.WORKER_UNKNOWN, registerTime))
+    statusTracker.excludedWorkers.put(mock("host2"), 
(StatusCode.WORKER_SHUTDOWN, registerTime))
+
+    // test reserve (only statusCode list in handleHeartbeatResponse)
+    val empty = buildResponse(Array.empty, Array.empty, Array.empty, 
Array.empty)
+    statusTracker.handleHeartbeatResponse(empty)
+
+    // only reserve host1
+    Assert.assertEquals(
+      statusTracker.excludedWorkers.get(mock("host1")),
+      (StatusCode.WORKER_UNKNOWN, registerTime))
+    
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host2")))
+
+    // add shutdown/excluded worker
+    val response1 =
+      buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"), 
Array.empty)
+    statusTracker.handleHeartbeatResponse(response1)
+
+    // test keep Unknown register time
+    Assert.assertEquals(
+      statusTracker.excludedWorkers.get(mock("host1")),
+      (StatusCode.WORKER_UNKNOWN, registerTime))
+    // test new added workers
+    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host0")))
+    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host3")))
+    
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
+    Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
+
+    // test re heartbeat with shutdown workers
+    val response2 = buildResponse(Array.empty, Array.empty, Array("host4"), 
Array.empty)
     statusTracker.handleHeartbeatResponse(response2)
+    
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
+    Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
+
+    // test remove
+    val workers = new util.HashSet[WorkerInfo]
+    workers.add(mock("host3"))
+    statusTracker.removeFromExcludedWorkers(workers)
+    
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host3")))
+
+    // test register time elapsed
+    Thread.sleep(3000)
+    val response3 = buildResponse(Array.empty, Array("host5", "host6"), 
Array.empty, Array.empty)
+    statusTracker.handleHeartbeatResponse(response3)
     Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
     
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host1")))
+
+    // test available workers
+    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
0)
+    val response4 = buildResponse(
+      Array.empty,
+      Array.empty,
+      Array.empty,
+      Array("host5", "host6", "host7", "host8"))
+    statusTracker.handleHeartbeatResponse(response4)
+
+    // availableWorkers wont update with excludedWorkers
+    // So before using them we hava to filter excludedWorkers
+    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
4)
+    // available workers won't overwrite excluded workers
+    Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
+    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host5")))
+    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host6")))
+
+    // test re heartbeat with available workers
+    val response5 = buildResponse(Array.empty, Array.empty, Array.empty, 
Array("host8", "host9"))
+    statusTracker.handleHeartbeatResponse(response5)
+    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
2)
+    
Assert.assertFalse(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host7")))
+    
Assert.assertTrue(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host8")))
+    
Assert.assertTrue(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host9")))
   }
 
   private def buildResponse(
       excludedWorkerHosts: Array[String],
       unknownWorkerHosts: Array[String],
-      shuttingWorkerHosts: Array[String]): HeartbeatFromApplicationResponse = {
+      shuttingWorkerHosts: Array[String],
+      availableWorkerHosts: Array[String]): HeartbeatFromApplicationResponse = 
{
     val excludedWorkers = mockWorkers(excludedWorkerHosts)
     val unknownWorkers = mockWorkers(unknownWorkerHosts)
     val shuttingWorkers = mockWorkers(shuttingWorkerHosts)
+    val availableWorkers = mockWorkers(availableWorkerHosts)
     HeartbeatFromApplicationResponse(
       StatusCode.SUCCESS,
       excludedWorkers,
       unknownWorkers,
       shuttingWorkers,
+      availableWorkers,
       new util.ArrayList[Integer]())
   }
 
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index bc8da08ab..d5b129b0d 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -442,6 +442,7 @@ message PbHeartbeatFromApplication {
   string requestId = 4;
   repeated PbWorkerInfo needCheckedWorkerList = 5;
   bool shouldResponse = 6;
+  bool needAvailableWorkers = 7;
 }
 
 message PbHeartbeatFromApplicationResponse {
@@ -450,6 +451,7 @@ message PbHeartbeatFromApplicationResponse {
   repeated PbWorkerInfo unknownWorkers = 3;
   repeated PbWorkerInfo shuttingWorkers = 4;
   repeated int32 registeredShuffles = 5;
+  repeated PbWorkerInfo availableWorkers = 6;
 }
 
 message PbCheckQuota {
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 9693146c3..97befab96 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -898,6 +898,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
   def clientReserveSlotsRetryWait: Long = get(CLIENT_RESERVE_SLOTS_RETRY_WAIT)
   def clientRequestCommitFilesMaxRetries: Int = 
get(CLIENT_COMMIT_FILE_REQUEST_MAX_RETRY)
   def clientCommitFilesIgnoreExcludedWorkers: Boolean = 
get(CLIENT_COMMIT_IGNORE_EXCLUDED_WORKERS)
+  def clientShuffleDynamicResourceEnabled: Boolean =
+    get[Boolean](CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED)
   def appHeartbeatTimeoutMs: Long = get(APPLICATION_HEARTBEAT_TIMEOUT)
   def hdfsExpireDirsTimeoutMS: Long = get(HDFS_EXPIRE_DIRS_TIMEOUT)
   def dfsExpireDirsTimeoutMS: Long = get(DFS_EXPIRE_DIRS_TIMEOUT)
@@ -4827,6 +4829,15 @@ object CelebornConf extends Logging {
       .booleanConf
       .createWithDefault(false)
 
+  val CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.client.shuffle.dynamicResourceEnabled")
+      .categories("client")
+      .version("0.6.0")
+      .doc("When enabled, the ChangePartitionManager will obtain candidate workers from the " +
+        "availableWorkers pool reported by the master in heartbeats when worker resources change.")
+      .booleanConf
+      .createWithDefault(false)
+
   val CLIENT_PUSH_STAGE_END_TIMEOUT: ConfigEntry[Long] =
     buildConf("celeborn.client.push.stageEnd.timeout")
       .withAlternative("celeborn.push.stageEnd.timeout")
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index d0eb85b02..e5ea22fd4 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -413,6 +413,7 @@ object ControlMessages extends Logging {
       totalWritten: Long,
       fileCount: Long,
       needCheckedWorkerList: util.List[WorkerInfo],
+      needAvailableWorkers: Boolean,
       override var requestId: String = ZERO_UUID,
       shouldResponse: Boolean = false) extends MasterRequestMessage
 
@@ -421,6 +422,7 @@ object ControlMessages extends Logging {
       excludedWorkers: util.List[WorkerInfo],
       unknownWorkers: util.List[WorkerInfo],
       shuttingWorkers: util.List[WorkerInfo],
+      availableWorkers: util.List[WorkerInfo],
       registeredShuffles: util.List[Integer]) extends Message
 
   case class CheckQuota(userIdentifier: UserIdentifier) extends Message
@@ -809,6 +811,7 @@ object ControlMessages extends Logging {
           totalWritten,
           fileCount,
           needCheckedWorkerList,
+          needAvailableWorkers,
           requestId,
           shouldResponse) =>
       val payload = PbHeartbeatFromApplication.newBuilder()
@@ -818,6 +821,7 @@ object ControlMessages extends Logging {
         .setFileCount(fileCount)
         .addAllNeedCheckedWorkerList(needCheckedWorkerList.asScala.map(
           PbSerDeUtils.toPbWorkerInfo(_, true, true)).toList.asJava)
+        .setNeedAvailableWorkers(needAvailableWorkers)
         .setShouldResponse(shouldResponse)
         .build().toByteArray
       new TransportMessage(MessageType.HEARTBEAT_FROM_APPLICATION, payload)
@@ -827,6 +831,7 @@ object ControlMessages extends Logging {
           excludedWorkers,
           unknownWorkers,
           shuttingWorkers,
+          availableWorkers,
           registeredShuffles) =>
       val payload = PbHeartbeatFromApplicationResponse.newBuilder()
         .setStatus(statusCode.getValue)
@@ -836,6 +841,8 @@ object ControlMessages extends Logging {
           unknownWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, 
true)).toList.asJava)
         .addAllShuttingWorkers(
           shuttingWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, 
true)).toList.asJava)
+        .addAllAvailableWorkers(
+          availableWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, 
true)).toList.asJava)
         .addAllRegisteredShuffles(registeredShuffles)
         .build().toByteArray
       new TransportMessage(MessageType.HEARTBEAT_FROM_APPLICATION_RESPONSE, 
payload)
@@ -1207,6 +1214,7 @@ object ControlMessages extends Logging {
           new util.ArrayList[WorkerInfo](
             pbHeartbeatFromApplication.getNeedCheckedWorkerListList.asScala
               .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava),
+          pbHeartbeatFromApplication.getNeedAvailableWorkers,
           pbHeartbeatFromApplication.getRequestId,
           pbHeartbeatFromApplication.getShouldResponse)
 
@@ -1221,6 +1229,8 @@ object ControlMessages extends Logging {
             .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
           pbHeartbeatFromApplicationResponse.getShuttingWorkersList.asScala
             .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
+          pbHeartbeatFromApplicationResponse.getAvailableWorkersList.asScala
+            .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
           pbHeartbeatFromApplicationResponse.getRegisteredShufflesList)
 
       case CHECK_QUOTA_VALUE =>
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index ce309a13f..492bd8133 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -95,6 +95,7 @@ license: |
 | celeborn.client.shuffle.compression.codec | LZ4 | false | The codec used to 
compress shuffle data. By default, Celeborn provides three codecs: `lz4`, 
`zstd`, `none`. `none` means that shuffle compression is disabled. Since Flink 
version 1.17, zstd is supported for Flink shuffle client. | 0.3.0 | 
celeborn.shuffle.compression.codec,remote-shuffle.job.compression.codec | 
 | celeborn.client.shuffle.compression.zstd.level | 1 | false | Compression 
level for Zstd compression codec, its value should be an integer between -5 and 
22. Increasing the compression level will result in better compression at the 
expense of more CPU and memory. | 0.3.0 | 
celeborn.shuffle.compression.zstd.level | 
 | celeborn.client.shuffle.decompression.lz4.xxhash.instance | 
&lt;undefined&gt; | false | Decompression XXHash instance for Lz4. Available 
options: JNI, JAVASAFE, JAVAUNSAFE. | 0.3.2 |  | 
+| celeborn.client.shuffle.dynamicResourceEnabled | false | false | When enabled, the ChangePartitionManager will obtain candidate workers from the availableWorkers pool reported by the master in heartbeats when worker resources change. | 0.6.0 |  | 
 | celeborn.client.shuffle.expired.checkInterval | 60s | false | Interval for 
client to check expired shuffles. | 0.3.0 | 
celeborn.shuffle.expired.checkInterval | 
 | celeborn.client.shuffle.manager.port | 0 | false | Port used by the 
LifecycleManager on the Driver. | 0.3.0 | celeborn.shuffle.manager.port | 
 | celeborn.client.shuffle.mapPartition.split.enabled | false | false | whether 
to enable shuffle partition split. Currently, this only applies to 
MapPartition. | 0.3.1 |  | 
diff --git 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index 86baaa921..bec47890e 100644
--- 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++ 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -407,6 +407,7 @@ private[celeborn] class Master(
           totalWritten,
           fileCount,
           needCheckedWorkerList,
+          needAvailableWorkers,
           requestId,
           shouldResponse) =>
       logDebug(s"Received heartbeat from app $appId")
@@ -419,6 +420,7 @@ private[celeborn] class Master(
           totalWritten,
           fileCount,
           needCheckedWorkerList,
+          needAvailableWorkers,
           requestId,
           shouldResponse))
 
@@ -1113,6 +1115,7 @@ private[celeborn] class Master(
       totalWritten: Long,
       fileCount: Long,
       needCheckedWorkerList: util.List[WorkerInfo],
+      needAvailableWorkers: Boolean,
       requestId: String,
       shouldResponse: Boolean): Unit = {
     statusSystem.handleAppHeartbeat(
@@ -1126,6 +1129,12 @@ private[celeborn] class Master(
     if (shouldResponse) {
       // UserResourceConsumption and DiskInfo are eliminated from WorkerInfo
       // during serialization of HeartbeatFromApplicationResponse
+      var availableWorksSentToClient = new util.ArrayList[WorkerInfo]()
+      if (needAvailableWorkers) {
+        availableWorksSentToClient = new util.ArrayList[WorkerInfo](
+          statusSystem.workers.asScala.filter(worker =>
+            statusSystem.isWorkerAvailable(worker)).asJava)
+      }
       var appRelatedShuffles =
         statusSystem.registeredAppAndShuffles.getOrDefault(appId, 
Collections.emptySet())
       context.reply(HeartbeatFromApplicationResponse(
@@ -1135,6 +1144,7 @@ private[celeborn] class Master(
         needCheckedWorkerList,
         new util.ArrayList[WorkerInfo](
           (statusSystem.shutdownWorkers.asScala ++ 
statusSystem.decommissionWorkers.asScala).asJava),
+        availableWorksSentToClient,
         new util.ArrayList(appRelatedShuffles)))
     } else {
       context.reply(OneWayMessageResponse)
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
new file mode 100644
index 000000000..bee3b61ec
--- /dev/null
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.client
+
+import java.util
+
+import scala.collection.JavaConverters.mapAsScalaMapConverter
+
+import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
+import org.scalatest.time.SpanSugar._
+
+import org.apache.celeborn.client.{ChangePartitionManager, ChangePartitionRequest, LifecycleManager, WithShuffleClientSuite}
+import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.{CelebornExitKind, JavaUtils}
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+/**
+ * Checks how ChangePartitionManager reacts to workers being added to and removed from the
+ * cluster, with `celeborn.client.shuffle.dynamicResourceEnabled` on and off.
+ *
+ * Both tests share one mini cluster and add workers to it; the first test expects the cluster
+ * to start with a single worker, so the tests rely on running in declaration order.
+ */
+class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite
+  with MiniClusterFeature {
+  celebornConf
+    .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false")
+    .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
+
+  // Longer than the 10s application heartbeat interval, so at least one heartbeat response
+  // should have been handled before asserting that the client's worker view did not change.
+  private val heartbeatWaitMs = 15000L
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val testConf = Map(
+      s"${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}" -> "3")
+    val (master, _) = setupMiniClusterWithRandomPorts(testConf, testConf, workerNum = 1)
+    celebornConf.set(
+      CelebornConf.MASTER_ENDPOINTS.key,
+      master.conf.get(CelebornConf.MASTER_ENDPOINTS.key))
+  }
+
+  test("test changePartition with available workers") {
+    val shuffleId = nextShuffleId
+    val conf = celebornConf.clone
+    conf.set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "3")
+      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key, "true")
+      .set(CelebornConf.CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED.key, "false")
+
+    val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+    // Stop the LifecycleManager even if an assertion fails, so it does not keep running
+    // into the following test.
+    try {
+      val changePartitionManager: ChangePartitionManager =
+        new ChangePartitionManager(conf, lifecycleManager)
+      assert(allocateAndRegisterSlots(lifecycleManager, shuffleId) == 1)
+      val tracker = lifecycleManager.workerStatusTracker
+      assert(lifecycleManager.workerSnapshots(shuffleId).size() == 1)
+      assert(tracker.availableWorkersWithEndpoint.size() == 1)
+      assert(tracker.availableWorkersWithoutEndpoint.size() == 0)
+
+      // total workerNum is 1 + 2 = 3 now; with dynamic resources enabled the new workers
+      // should reach the client through the next heartbeat response.
+      setUpWorkers(workerConfForAdding, 2)
+      eventually(timeout(30.seconds), interval(500.millis)) {
+        assert(tracker.availableWorkersWithEndpoint.size() == 1)
+        assert(tracker.availableWorkersWithoutEndpoint.size() == 2)
+      }
+      assert(workerInfos.size == 3)
+
+      requestChangePartitions(lifecycleManager, changePartitionManager, shuffleId)
+      assert(tracker.availableWorkersWithoutEndpoint.size() +
+        tracker.availableWorkersWithEndpoint.size() == 3)
+      assert(lifecycleManager.workerSnapshots(shuffleId).size() > 1)
+      assert(tracker.availableWorkersWithEndpoint.size() > 1)
+
+      // Stop two of the three workers. Workers in MiniClusterFeature won't update their status
+      // with the master through heartbeats, so exclude them on the master manually.
+      workerInfos.keys.take(2).toList.foreach { worker =>
+        worker.stop(CelebornExitKind.EXIT_IMMEDIATELY)
+        worker.rpcEnv.shutdown()
+        masterInfo._1.statusSystem.excludedWorkers.add(worker.workerInfo)
+        workerInfos.remove(worker)
+      }
+
+      eventually(timeout(30.seconds), interval(500.millis)) {
+        assert(tracker.availableWorkersWithEndpoint.size() +
+          tracker.availableWorkersWithoutEndpoint.size() == 1)
+      }
+    } finally {
+      lifecycleManager.stop()
+    }
+  }
+
+  test("test changePartition without available workers") {
+    val shuffleId = nextShuffleId
+    val conf = celebornConf.clone
+    conf.set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "3")
+      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key, "false")
+      .set(CelebornConf.CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED.key, "false")
+
+    val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+    try {
+      val changePartitionManager: ChangePartitionManager =
+        new ChangePartitionManager(conf, lifecycleManager)
+      // workerNum is 1 (after 1 add 2 and stop 2)
+      val workerNum = allocateAndRegisterSlots(lifecycleManager, shuffleId)
+      assert(workerNum == 1)
+      val tracker = lifecycleManager.workerStatusTracker
+      assert(lifecycleManager.workerSnapshots(shuffleId).size() == workerNum)
+      assert(tracker.availableWorkersWithEndpoint.size() == workerNum)
+
+      // total workerNum is 1 + 2 = 3 now. With dynamic resources disabled nothing should
+      // change, so there is no condition to poll for: just wait past one heartbeat.
+      setUpWorkers(workerConfForAdding, 2)
+      Thread.sleep(heartbeatWaitMs)
+      assert(workerInfos.size == 3)
+      assert(tracker.availableWorkersWithEndpoint.size() == workerNum)
+      assert(tracker.availableWorkersWithoutEndpoint.size() == 0)
+
+      requestChangePartitions(lifecycleManager, changePartitionManager, shuffleId)
+      val reallocatedWorkers = lifecycleManager.workerSnapshots(shuffleId)
+      logInfo(s"reallocated worker num: ${reallocatedWorkers.size()}")
+      assert(reallocatedWorkers.size() == workerNum)
+      assert(tracker.availableWorkersWithEndpoint.size() == workerNum)
+      assert(tracker.availableWorkersWithoutEndpoint.size() == 0)
+    } finally {
+      lifecycleManager.stop()
+    }
+  }
+
+  /**
+   * Requests slots for partitions 0 until 10 of `shuffleId`, sets up endpoints for and reserves
+   * the granted workers and, if the reservation succeeds, records them in
+   * `lifecycleManager.shuffleAllocatedWorkers` and the worker status tracker.
+   *
+   * @return the number of workers the master granted slots on
+   */
+  private def allocateAndRegisterSlots(
+      lifecycleManager: LifecycleManager,
+      shuffleId: Int): Int = {
+    val ids = new util.ArrayList[Integer](10)
+    (0 until 10).foreach(i => ids.add(i))
+    val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids)
+    assert(res.status == StatusCode.SUCCESS)
+    val grantedWorkerCount = res.workerResource.keySet().size()
+
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet(),
+      shuffleId,
+      new ShuffleFailedWorkers())
+
+    val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry(
+      shuffleId,
+      new util.HashSet(res.workerResource.keySet()),
+      res.workerResource,
+      updateEpoch = false)
+
+    if (reserveSlotsSuccess) {
+      val allocatedWorkers =
+        JavaUtils.newConcurrentHashMap[WorkerInfo, ShufflePartitionLocationInfo]()
+      res.workerResource.asScala.foreach {
+        case (workerInfo, (primaryLocations, replicaLocations)) =>
+          val partitionLocationInfo = new ShufflePartitionLocationInfo()
+          partitionLocationInfo.addPrimaryPartitions(primaryLocations)
+          partitionLocationInfo.addReplicaPartitions(replicaLocations)
+          allocatedWorkers.put(workerInfo, partitionLocationInfo)
+      }
+      lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
+      lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(
+        new util.HashSet(res.workerResource.keySet()))
+    }
+    grantedWorkerCount
+  }
+
+  /** Feeds one ChangePartitionRequest per partition (0 until 10) to handleRequestPartitions. */
+  private def requestChangePartitions(
+      lifecycleManager: LifecycleManager,
+      changePartitionManager: ChangePartitionManager,
+      shuffleId: Int): Unit = {
+    (0 until 10).foreach { partitionId =>
+      val req = ChangePartitionRequest(
+        null,
+        shuffleId,
+        partitionId,
+        -1,
+        null,
+        None)
+      changePartitionManager.changePartitionRequests.computeIfAbsent(
+        shuffleId,
+        changePartitionManager.rpcContextRegisterFunc)
+      changePartitionManager.handleRequestPartitions(
+        shuffleId,
+        Array(req),
+        lifecycleManager.commitManager.isSegmentGranularityVisible(shuffleId))
+    }
+  }
+
+  override def afterAll(): Unit = {
+    logInfo("all test complete , stop celeborn mini cluster")
+    try {
+      shutdownMiniCluster()
+    } finally {
+      super.afterAll()
+    }
+  }
+}
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
index 90192e283..1f7769140 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
@@ -58,7 +58,10 @@ class LifecycleManagerCommitFilesSuite extends 
WithShuffleClientSuite with MiniC
     assert(res.status == StatusCode.SUCCESS)
     assert(res.workerResource.keySet().size() == 3)
 
-    lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new 
ShuffleFailedWorkers())
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet(),
+      shuffleId,
+      new ShuffleFailedWorkers())
 
     lifecycleManager.reserveSlotsWithRetry(
       shuffleId,
@@ -108,7 +111,10 @@ class LifecycleManagerCommitFilesSuite extends 
WithShuffleClientSuite with MiniC
     assert(res.status == StatusCode.SUCCESS)
     assert(res.workerResource.keySet().size() == 3)
 
-    lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new 
ShuffleFailedWorkers())
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet(),
+      shuffleId,
+      new ShuffleFailedWorkers())
 
     lifecycleManager.reserveSlotsWithRetry(
       shuffleId,
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala
index 1fc639d60..e83268202 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala
@@ -56,7 +56,10 @@ class LifecycleManagerDestroySlotsSuite extends 
WithShuffleClientSuite with Mini
     assert(res.status == StatusCode.SUCCESS)
     assert(res.workerResource.keySet().size() == 3)
 
-    lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new 
ShuffleFailedWorkers())
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet(),
+      shuffleId,
+      new ShuffleFailedWorkers())
 
     lifecycleManager.reserveSlotsWithRetry(
       shuffleId,
@@ -95,7 +98,10 @@ class LifecycleManagerDestroySlotsSuite extends 
WithShuffleClientSuite with Mini
     assert(res.status == StatusCode.SUCCESS)
     assert(res.workerResource.keySet().size() == 3)
 
-    lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new 
ShuffleFailedWorkers())
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet(),
+      shuffleId,
+      new ShuffleFailedWorkers())
 
     lifecycleManager.reserveSlotsWithRetry(
       shuffleId,
@@ -134,7 +140,10 @@ class LifecycleManagerDestroySlotsSuite extends 
WithShuffleClientSuite with Mini
     assert(res.status == StatusCode.SUCCESS)
     assert(res.workerResource.keySet().size() == 3)
 
-    lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new 
ShuffleFailedWorkers())
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet(),
+      shuffleId,
+      new ShuffleFailedWorkers())
 
     lifecycleManager.reserveSlotsWithRetry(
       shuffleId,
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala
index 8fee9cbb9..9ad455756 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala
@@ -52,7 +52,7 @@ class LifecycleManagerSetupEndpointSuite extends 
WithShuffleClientSuite with Min
     assert(res.workerResource.keySet().size() == 3)
 
     val connectFailedWorkers = new ShuffleFailedWorkers()
-    lifecycleManager.setupEndpoints(res.workerResource, 0, 
connectFailedWorkers)
+    lifecycleManager.setupEndpoints(res.workerResource.keySet(), 0, 
connectFailedWorkers)
     assert(connectFailedWorkers.isEmpty)
 
     lifecycleManager.stop()
@@ -73,7 +73,7 @@ class LifecycleManagerSetupEndpointSuite extends 
WithShuffleClientSuite with Min
     firstWorker.rpcEnv.shutdown()
 
     val connectFailedWorkers = new ShuffleFailedWorkers()
-    lifecycleManager.setupEndpoints(res.workerResource, 0, 
connectFailedWorkers)
+    lifecycleManager.setupEndpoints(res.workerResource.keySet(), 0, 
connectFailedWorkers)
     assert(connectFailedWorkers.size() == 1)
     assert(connectFailedWorkers.keySet().asScala.head == 
firstWorker.workerInfo)
 
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
index ea6eed601..4e3b7e6ff 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -32,21 +32,48 @@ class RetryReviveTest extends AnyFunSuite
 
   override def beforeAll(): Unit = {
     logInfo("test initialized , setup celeborn mini cluster")
-    setupMiniClusterWithRandomPorts()
   }
 
-  override def beforeEach(): Unit = {
-    ShuffleClient.reset()
-  }
+  override def beforeEach(): Unit = {}
 
   override def afterEach(): Unit = {
     System.gc()
   }
 
   test("celeborn spark integration test - retry revive as configured times") {
+    setupMiniClusterWithRandomPorts()
+    ShuffleClient.reset()
+    val sparkConf = new SparkConf()
+      .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "true")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}", "3")
+      .setAppName("celeborn-demo").setMaster("local[2]")
+    val ss = SparkSession.builder()
+      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+      .getOrCreate()
+    // Stop the session even if an assertion fails: getOrCreate() in the next test would
+    // otherwise reuse this still-active session (and its SparkContext settings).
+    try {
+      val result = ss.sparkContext.parallelize(1 to 1000, 2)
+        .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(4).collect()
+      assert(result.size == 1000)
+    } finally {
+      ss.stop()
+    }
+  }
+
+  test(
+    "celeborn spark integration test - e2e test retry revive with available 
workers from heartbeat") {
+    val testConf = Map(
+      s"${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}" -> "3",
+      s"${CelebornConf.MASTER_SLOT_ASSIGN_EXTRA_SLOTS.key}" -> "0")
+    setupMiniClusterWithRandomPorts(testConf)
+    ShuffleClient.reset()
     val sparkConf = new SparkConf()
       .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "true")
       .set(s"spark.${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}", "3")
+      .set(s"spark.${CelebornConf.CLIENT_SLOT_ASSIGN_MAX_WORKERS.key}", "1")
+      
.set(s"spark.${CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key}", 
"true")
+      .set(s"spark.${CelebornConf.MASTER_SLOT_ASSIGN_EXTRA_SLOTS.key}", "0")
       .setAppName("celeborn-demo").setMaster("local[2]")
     val ss = SparkSession.builder()
       .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
index fb6bd67f6..7c1280f13 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
@@ -37,6 +37,7 @@ trait MiniClusterFeature extends Logging {
 
   var masterInfo: (Master, Thread) = _
   val workerInfos = new mutable.HashMap[Worker, Thread]()
+  var workerConfForAdding: Map[String, String] = _ // conf the mini cluster's workers started with
 
   class RunnerWrap[T](code: => T) extends Thread {
 
@@ -71,6 +72,7 @@ trait MiniClusterFeature extends Logging {
           workerConf
         logInfo(
           s"generated configuration. Master conf = $finalMasterConf, worker 
conf = $finalWorkerConf")
+        workerConfForAdding = finalWorkerConf
         val (m, w) =
           setUpMiniCluster(masterConf = finalMasterConf, workerConf = 
finalWorkerConf, workerNum)
         master = m
@@ -148,10 +150,7 @@ trait MiniClusterFeature extends Logging {
     }
   }
 
-  private def setUpMiniCluster(
-      masterConf: Map[String, String] = null,
-      workerConf: Map[String, String] = null,
-      workerNum: Int = 3): (Master, collection.Set[Worker]) = {
+  def setUpMaster(masterConf: Map[String, String] = null): Master = {
     val timeout = 30000
     val master = createMaster(masterConf)
     val masterStartedSignal = Array(false)
@@ -176,7 +175,13 @@ trait MiniClusterFeature extends Logging {
         throw new BindException("cannot start master rpc endpoint")
       }
     }
+    master
+  }
 
+  def setUpWorkers(
+      workerConf: Map[String, String] = null,
+      workerNum: Int = 3): collection.Set[Worker] = {
+    val timeout = 30000
     val workers = new Array[Worker](workerNum)
     val flagUpdateLock = new ReentrantLock()
     val threads = (1 to workerNum).map { i =>
@@ -239,7 +244,14 @@ trait MiniClusterFeature extends Logging {
           }
       }
     }
-    (master, workerInfos.keySet)
+    workerInfos.keySet
+  }
+
+  private def setUpMiniCluster(
+      masterConf: Map[String, String] = null,
+      workerConf: Map[String, String] = null,
+      workerNum: Int = 3): (Master, collection.Set[Worker]) = {
+    setUpMaster(masterConf) -> setUpWorkers(workerConf, workerNum) // master starts first
   }
 
   def shutdownMiniCluster(): Unit = {

Reply via email to