This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 64f201dd8 [CELEBORN-1636][FOLLOWUP] Dynamic resources will only be 
utilized in case of candidates shortages
64f201dd8 is described below

commit 64f201dd83d025904c035b75e9907adced0968f6
Author: szt <[email protected]>
AuthorDate: Tue Nov 5 18:10:01 2024 +0800

    [CELEBORN-1636][FOLLOWUP] Dynamic resources will only be utilized in case 
of candidates shortages
    
    ### What changes were proposed in this pull request?
    Follow up of [https://github.com/apache/celeborn/pull/2835]
    Only use dynamic resources when candidates are not enough.
    And change the way of getting availableWorkers from heartbeat to requestSlots 
RPC to avoid the burden of heartbeat.
    
    ### Why are the changes needed?
    No
    
    ### Does this PR introduce _any_ user-facing change?
    Add another configuration.
    
    ### How was this patch tested?
    UT
    
    Closes #2852 from zaynt4606/clb1636-flu2.
    
    Authored-by: szt <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../celeborn/client/ApplicationHeartbeater.scala   |   3 -
 .../celeborn/client/ChangePartitionManager.scala   | 103 ++++++----
 .../apache/celeborn/client/LifecycleManager.scala  |   1 -
 .../celeborn/client/WorkerStatusTracker.scala      |  50 +----
 .../celeborn/client/WorkerStatusTrackerSuite.scala | 110 +----------
 common/src/main/proto/TransportMessages.proto      |   4 +-
 .../org/apache/celeborn/common/CelebornConf.scala  |  12 ++
 .../common/protocol/message/ControlMessages.scala  |  10 -
 docs/configuration/client.md                       |   1 +
 .../celeborn/service/deploy/master/Master.scala    |  12 +-
 .../ChangePartitionManagerUpdateWorkersSuite.scala | 219 +++++++++++++++++----
 .../celeborn/tests/spark/RetryReviveTest.scala     |   4 +-
 12 files changed, 272 insertions(+), 257 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala 
b/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
index b73582745..b558df596 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
@@ -48,7 +48,6 @@ class ApplicationHeartbeater(
   // Use independent app heartbeat threads to avoid being blocked by other 
operations.
   private val appHeartbeatIntervalMs = conf.appHeartbeatIntervalMs
   private val applicationUnregisterEnabled = conf.applicationUnregisterEnabled
-  private val clientShuffleDynamicResourceEnabled = 
conf.clientShuffleDynamicResourceEnabled
   private val appHeartbeatHandlerThread =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor(
       "celeborn-client-lifecycle-manager-app-heartbeater")
@@ -71,7 +70,6 @@ class ApplicationHeartbeater(
                 tmpTotalWritten,
                 tmpTotalFileCount,
                 workerStatusTracker.getNeedCheckedWorkers().toList.asJava,
-                clientShuffleDynamicResourceEnabled,
                 ZERO_UUID,
                 true)
             val response = requestHeartbeat(appHeartbeat)
@@ -134,7 +132,6 @@ class ApplicationHeartbeater(
           List.empty.asJava,
           List.empty.asJava,
           List.empty.asJava,
-          List.empty.asJava,
           CheckQuotaResponse(isAvailable = true, ""))
     }
   }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index 771b51151..3daeac5ab 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -76,7 +76,8 @@ class ChangePartitionManager(
 
   private val testRetryRevive = conf.testRetryRevive
 
-  private val clientShuffleDynamicResourceEnabled = 
conf.clientShuffleDynamicResourceEnabled
+  private val dynamicResourceEnabled = conf.clientShuffleDynamicResourceEnabled
+  private val dynamicResourceUnavailableFactor = 
conf.clientShuffleDynamicResourceFactor
 
   def start(): Unit = {
     batchHandleChangePartition = batchHandleChangePartitionSchedulerThread.map 
{
@@ -291,42 +292,69 @@ class ChangePartitionManager(
     }
 
     val candidates = new util.HashSet[WorkerInfo]()
-    if (clientShuffleDynamicResourceEnabled) {
-      // availableWorkers wont filter excludedWorkers in heartBeat So have to 
do filtering.
-      candidates.addAll(lifecycleManager
-        .workerStatusTracker
-        .availableWorkersWithEndpoint
-        .values()
+    val newlyRequestedLocations = new WorkerResource()
+
+    val snapshotCandidates =
+      lifecycleManager
+        .workerSnapshots(shuffleId)
+        .keySet()
         .asScala
-        .toSet
         .filter(lifecycleManager.workerStatusTracker.workerAvailable)
-        .asJava)
-
-      // SetupEndpoint for those availableWorkers without endpoint
-      val workersRequireEndpoints = new util.HashSet[WorkerInfo](
-        
lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.asScala.filter(
-          lifecycleManager.workerStatusTracker.workerAvailable).asJava)
-      val connectFailedWorkers = new ShuffleFailedWorkers()
-      lifecycleManager.setupEndpoints(
-        workersRequireEndpoints,
-        shuffleId,
-        connectFailedWorkers)
-      
workersRequireEndpoints.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
-      candidates.addAll(workersRequireEndpoints)
-
-      // Update worker status
-      
lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(workersRequireEndpoints)
-      
lifecycleManager.workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
-      
lifecycleManager.workerStatusTracker.removeFromExcludedWorkers(candidates)
-    } else {
-      val snapshotCandidates =
-        lifecycleManager
-          .workerSnapshots(shuffleId)
-          .keySet()
-          .asScala
-          .filter(lifecycleManager.workerStatusTracker.workerAvailable)
-          .asJava
-      candidates.addAll(snapshotCandidates)
+        .asJava
+    candidates.addAll(snapshotCandidates)
+
+    if (dynamicResourceEnabled) {
+      val shuffleAllocatedWorkers = 
lifecycleManager.workerSnapshots(shuffleId).size()
+      val unavailableWorkerRatio = 1 - (snapshotCandidates.size * 1.0 / 
shuffleAllocatedWorkers)
+      if (candidates.size < 1 || (pushReplicateEnabled && candidates.size < 2)
+        || (unavailableWorkerRatio >= dynamicResourceUnavailableFactor)) {
+
+        // get new available workers for the request partition ids
+        val partitionIds = new util.ArrayList[Integer](
+          
changePartitions.map(_.partitionId).map(Integer.valueOf).toList.asJava)
+        // The partition id value is not important here because we're just 
trying to get the workers to use
+        val requestSlotsRes =
+          lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 
partitionIds)
+
+        requestSlotsRes.status match {
+          case StatusCode.REQUEST_FAILED =>
+            logInfo(s"ChangePartition requestSlots RPC request failed for 
$shuffleId!")
+          case StatusCode.SLOT_NOT_AVAILABLE =>
+            logInfo(s"ChangePartition requestSlots for $shuffleId failed, have 
no available slots.")
+          case StatusCode.SUCCESS =>
+            logDebug(
+              s"ChangePartition requestSlots request for workers Success! 
shuffleId: $shuffleId availableWorkers Info: 
${requestSlotsRes.workerResource.keySet()}")
+          case StatusCode.WORKER_EXCLUDED =>
+            logInfo(s"ChangePartition requestSlots request for workers for 
$shuffleId failed due to all workers be excluded!")
+          case _ => // won't happen
+            throw new UnsupportedOperationException()
+        }
+
+        if (requestSlotsRes.status.equals(StatusCode.SUCCESS)) {
+          requestSlotsRes.workerResource.keySet().asScala.foreach { 
workerInfo: WorkerInfo =>
+            newlyRequestedLocations.computeIfAbsent(workerInfo, 
lifecycleManager.newLocationFunc)
+          }
+
+          // SetupEndpoint for new Workers
+          val workersRequireEndpoints = new util.HashSet[WorkerInfo](
+            requestSlotsRes.workerResource.keySet()
+              .asScala
+              .filter(lifecycleManager.workerStatusTracker.workerAvailable)
+              .asJava)
+
+          val connectFailedWorkers = new ShuffleFailedWorkers()
+          lifecycleManager.setupEndpoints(
+            workersRequireEndpoints,
+            shuffleId,
+            connectFailedWorkers)
+          
workersRequireEndpoints.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
+          candidates.addAll(workersRequireEndpoints)
+
+          // Update worker status
+          
lifecycleManager.workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
+          
lifecycleManager.workerStatusTracker.removeFromExcludedWorkers(candidates)
+        }
+      }
     }
 
     if (candidates.size < 1 || (pushReplicateEnabled && candidates.size < 2)) {
@@ -351,7 +379,10 @@ class ChangePartitionManager(
       return
     }
 
-    val newPrimaryLocations = newlyAllocatedLocations.asScala.flatMap {
+    // newlyRequestedLocations is empty if dynamicResourceEnabled is false
+    newlyRequestedLocations.putAll(newlyAllocatedLocations)
+
+    val newPrimaryLocations = newlyRequestedLocations.asScala.flatMap {
       case (workInfo, (primaryLocations, replicaLocations)) =>
         // Add all re-allocated slots to worker snapshots.
         val partitionLocationInfo = 
lifecycleManager.workerSnapshots(shuffleId).computeIfAbsent(
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 92ccc66e2..afab1a56e 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -713,7 +713,6 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         allocatedWorkers.put(workerInfo, partitionLocationInfo)
       }
       shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
-      workerStatusTracker.addWorkersWithEndpoint(candidatesWorkers)
       registeredShuffle.add(shuffleId)
       commitManager.registerShuffle(
         shuffleId,
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala 
b/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
index 088341813..65dbe1e9a 100644
--- a/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
@@ -30,27 +30,17 @@ import org.apache.celeborn.common.meta.WorkerInfo
 import org.apache.celeborn.common.protocol.PartitionLocation
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.HeartbeatFromApplicationResponse
 import org.apache.celeborn.common.protocol.message.StatusCode
-import org.apache.celeborn.common.util.{JavaUtils, Utils}
+import org.apache.celeborn.common.util.Utils
 
 class WorkerStatusTracker(
     conf: CelebornConf,
     lifecycleManager: LifecycleManager) extends Logging {
   private val excludedWorkerExpireTimeout = 
conf.clientExcludedWorkerExpireTimeout
   private val workerStatusListeners = 
ConcurrentHashMap.newKeySet[WorkerStatusListener]()
-  private val clientShuffleDynamicResourceEnabled = 
conf.clientShuffleDynamicResourceEnabled
 
   val excludedWorkers = new ShuffleFailedWorkers()
   val shuttingWorkers: JSet[WorkerInfo] = new JHashSet[WorkerInfo]()
 
-  // Workers that have already set an endpoint can skip the setupEndpoint 
process in changePartition when reviving
-  // key: WorkerInfo.toUniqueId value: WorkerInfo
-  val availableWorkersWithEndpoint: ConcurrentHashMap[String, WorkerInfo] =
-    JavaUtils.newConcurrentHashMap[String, WorkerInfo]()
-
-  // Workers that may be available but have not been used(without endpoint)
-  // availableWorkersWithoutEndpoint is empty until 
appHeartbeatWithAvailableWorkers set to true
-  val availableWorkersWithoutEndpoint = 
ConcurrentHashMap.newKeySet[WorkerInfo]()
-
   def registerWorkerStatusListener(workerStatusListener: 
WorkerStatusListener): Unit = {
     workerStatusListeners.add(workerStatusListener)
   }
@@ -141,16 +131,13 @@ class WorkerStatusTracker(
       failedWorkers.asScala.foreach {
         case (worker, (StatusCode.WORKER_SHUTDOWN, _)) =>
           shuttingWorkers.add(worker)
-          removeFromAvailableWorkers(worker)
         case (worker, (statusCode, registerTime)) if 
!excludedWorkers.containsKey(worker) =>
           excludedWorkers.put(worker, (statusCode, registerTime))
-          removeFromAvailableWorkers(worker)
         case (worker, (statusCode, _))
             if statusCode == StatusCode.NO_AVAILABLE_WORKING_DIR ||
               statusCode == StatusCode.RESERVE_SLOTS_FAILED ||
               statusCode == StatusCode.WORKER_UNKNOWN =>
           excludedWorkers.put(worker, (statusCode, 
excludedWorkers.get(worker)._2))
-          removeFromAvailableWorkers(worker)
         case _ => // Not cover
       }
     }
@@ -160,22 +147,10 @@ class WorkerStatusTracker(
     excludedWorkers.keySet.removeAll(workers)
   }
 
-  private def removeFromAvailableWorkers(worker: WorkerInfo): Unit = {
-    availableWorkersWithEndpoint.remove(worker.toUniqueId())
-    availableWorkersWithoutEndpoint.remove(worker)
-  }
-
-  def addWorkersWithEndpoint(workers: JHashSet[WorkerInfo]): Unit = {
-    availableWorkersWithoutEndpoint.removeAll(workers)
-    workers.asScala.foreach { workerInfo =>
-      availableWorkersWithEndpoint.put(workerInfo.toUniqueId(), workerInfo)
-    }
-  }
-
   def handleHeartbeatResponse(res: HeartbeatFromApplicationResponse): Unit = {
     if (res.statusCode == StatusCode.SUCCESS) {
       logDebug(s"Received Worker status from Primary, excluded workers: 
${res.excludedWorkers} " +
-        s"unknown workers: ${res.unknownWorkers}, shutdown workers: 
${res.shuttingWorkers}, available workers from heartbeat: 
${res.availableWorkers}")
+        s"unknown workers: ${res.unknownWorkers}, shutdown workers: 
${res.shuttingWorkers}")
       val current = System.currentTimeMillis()
       var statusChanged = false
 
@@ -217,27 +192,6 @@ class WorkerStatusTracker(
       val retainShuttingWorkersResult = 
shuttingWorkers.retainAll(res.shuttingWorkers)
       val addShuttingWorkersResult = 
shuttingWorkers.addAll(res.shuttingWorkers)
 
-      if (clientShuffleDynamicResourceEnabled) {
-        // AvailableWorkers filter Client excludedWorkers and shuttingWorkers.
-        // AvailableWorkers already filtered res.excludedWorkers and 
res.shuttingWorkers.
-        val resAvailableWorkers: JSet[WorkerInfo] = new 
JHashSet[WorkerInfo](res.availableWorkers)
-        // update availableWorkers
-        // availableWorkers wont filter excludedWorkers.
-        // So before using them we hava to filter excludedWorkers.
-        availableWorkersWithoutEndpoint.retainAll(resAvailableWorkers)
-        availableWorkersWithEndpoint.keySet().retainAll(
-          resAvailableWorkers.asScala.map(_.toUniqueId()).asJava)
-        resAvailableWorkers.asScala.foreach { workerInfo: WorkerInfo =>
-          if 
(!availableWorkersWithEndpoint.keySet.contains(workerInfo.toUniqueId())) {
-            availableWorkersWithoutEndpoint.add(workerInfo)
-          } else {
-            if (availableWorkersWithoutEndpoint.contains(workerInfo)) {
-              availableWorkersWithoutEndpoint.remove(workerInfo)
-            }
-          }
-        }
-      }
-
       statusChanged =
         statusChanged || retainShuttingWorkersResult || 
addShuttingWorkersResult
       // Always trigger commit files for shutting down workers from 
HeartbeatFromApplicationResponse
diff --git 
a/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
 
b/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
index 66219f957..efaeb8439 100644
--- 
a/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
+++ 
b/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
@@ -40,7 +40,7 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
     statusTracker.excludedWorkers.put(mock("host2"), 
(StatusCode.WORKER_SHUTDOWN, registerTime))
 
     // test reserve (only statusCode list in handleHeartbeatResponse)
-    val empty = buildResponse(Array.empty, Array.empty, Array.empty, 
Array.empty)
+    val empty = buildResponse(Array.empty, Array.empty, Array.empty)
     statusTracker.handleHeartbeatResponse(empty)
 
     // only reserve host1
@@ -51,7 +51,7 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
 
     // add shutdown/excluded worker
     val response1 =
-      buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"), 
Array.empty)
+      buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"))
     statusTracker.handleHeartbeatResponse(response1)
 
     // test keep Unknown register time
@@ -66,7 +66,7 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
     Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
 
     // test re heartbeat with shutdown workers
-    val response2 = buildResponse(Array.empty, Array.empty, Array("host4"), 
Array.empty)
+    val response2 = buildResponse(Array.empty, Array.empty, Array("host4"))
     statusTracker.handleHeartbeatResponse(response2)
     
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
     Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
@@ -79,124 +79,24 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
 
     // test register time elapsed
     Thread.sleep(3000)
-    val response3 = buildResponse(Array.empty, Array("host5", "host6"), 
Array.empty, Array.empty)
+    val response3 = buildResponse(Array.empty, Array("host5", "host6"), 
Array.empty)
     statusTracker.handleHeartbeatResponse(response3)
     Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
     
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host1")))
-
-    // test available workers
-    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
0)
-    val response4 = buildResponse(
-      Array.empty,
-      Array.empty,
-      Array.empty,
-      Array("host5", "host6", "host7", "host8"))
-    statusTracker.handleHeartbeatResponse(response4)
-
-    // availableWorkers wont update through heartbeat
-    // when DYNAMIC_RESOURCE_ENABLE set to false
-    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
0)
-    // available workers won't overwrite excluded workers
-    Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
-    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host5")))
-    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host6")))
-  }
-
-  test("handleHeartbeatResponse with availableWorkers") {
-    val celebornConf = new CelebornConf()
-    celebornConf.set(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT, 2000L)
-    celebornConf.set(CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED, true)
-    val statusTracker = new WorkerStatusTracker(celebornConf, null)
-
-    val registerTime = System.currentTimeMillis()
-    statusTracker.excludedWorkers.put(mock("host1"), 
(StatusCode.WORKER_UNKNOWN, registerTime))
-    statusTracker.excludedWorkers.put(mock("host2"), 
(StatusCode.WORKER_SHUTDOWN, registerTime))
-
-    // test reserve (only statusCode list in handleHeartbeatResponse)
-    val empty = buildResponse(Array.empty, Array.empty, Array.empty, 
Array.empty)
-    statusTracker.handleHeartbeatResponse(empty)
-
-    // only reserve host1
-    Assert.assertEquals(
-      statusTracker.excludedWorkers.get(mock("host1")),
-      (StatusCode.WORKER_UNKNOWN, registerTime))
-    
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host2")))
-
-    // add shutdown/excluded worker
-    val response1 =
-      buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"), 
Array.empty)
-    statusTracker.handleHeartbeatResponse(response1)
-
-    // test keep Unknown register time
-    Assert.assertEquals(
-      statusTracker.excludedWorkers.get(mock("host1")),
-      (StatusCode.WORKER_UNKNOWN, registerTime))
-    // test new added workers
-    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host0")))
-    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host3")))
-    
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
-    Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
-
-    // test re heartbeat with shutdown workers
-    val response2 = buildResponse(Array.empty, Array.empty, Array("host4"), 
Array.empty)
-    statusTracker.handleHeartbeatResponse(response2)
-    
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
-    Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
-
-    // test remove
-    val workers = new util.HashSet[WorkerInfo]
-    workers.add(mock("host3"))
-    statusTracker.removeFromExcludedWorkers(workers)
-    
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host3")))
-
-    // test register time elapsed
-    Thread.sleep(3000)
-    val response3 = buildResponse(Array.empty, Array("host5", "host6"), 
Array.empty, Array.empty)
-    statusTracker.handleHeartbeatResponse(response3)
-    Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
-    
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host1")))
-
-    // test available workers
-    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
0)
-    val response4 = buildResponse(
-      Array.empty,
-      Array.empty,
-      Array.empty,
-      Array("host5", "host6", "host7", "host8"))
-    statusTracker.handleHeartbeatResponse(response4)
-
-    // availableWorkers wont update with excludedWorkers
-    // So before using them we hava to filter excludedWorkers
-    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
4)
-    // available workers won't overwrite excluded workers
-    Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
-    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host5")))
-    Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host6")))
-
-    // test re heartbeat with available workers
-    val response5 = buildResponse(Array.empty, Array.empty, Array.empty, 
Array("host8", "host9"))
-    statusTracker.handleHeartbeatResponse(response5)
-    Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(), 
2)
-    
Assert.assertFalse(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host7")))
-    
Assert.assertTrue(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host8")))
-    
Assert.assertTrue(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host9")))
   }
 
   private def buildResponse(
       excludedWorkerHosts: Array[String],
       unknownWorkerHosts: Array[String],
-      shuttingWorkerHosts: Array[String],
-      availableWorkerHosts: Array[String]): HeartbeatFromApplicationResponse = 
{
+      shuttingWorkerHosts: Array[String]): HeartbeatFromApplicationResponse = {
     val excludedWorkers = mockWorkers(excludedWorkerHosts)
     val unknownWorkers = mockWorkers(unknownWorkerHosts)
     val shuttingWorkers = mockWorkers(shuttingWorkerHosts)
-    val availableWorkers = mockWorkers(availableWorkerHosts)
     HeartbeatFromApplicationResponse(
       StatusCode.SUCCESS,
       excludedWorkers,
       unknownWorkers,
       shuttingWorkers,
-      availableWorkers,
       new util.ArrayList[Integer](),
       null)
   }
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 3ed029676..9ed21d180 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -442,7 +442,6 @@ message PbHeartbeatFromApplication {
   string requestId = 4;
   repeated PbWorkerInfo needCheckedWorkerList = 5;
   bool shouldResponse = 6;
-  bool needAvailableWorkers = 7;
 }
 
 message PbHeartbeatFromApplicationResponse {
@@ -451,8 +450,7 @@ message PbHeartbeatFromApplicationResponse {
   repeated PbWorkerInfo unknownWorkers = 3;
   repeated PbWorkerInfo shuttingWorkers = 4;
   repeated int32 registeredShuffles = 5;
-  repeated PbWorkerInfo availableWorkers = 6;
-  PbCheckQuotaResponse checkQuotaResponse = 7;
+  PbCheckQuotaResponse checkQuotaResponse = 6;
 }
 
 message PbCheckQuota {
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 046c9c95a..a10d58daa 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -901,6 +901,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
   def clientCommitFilesIgnoreExcludedWorkers: Boolean = 
get(CLIENT_COMMIT_IGNORE_EXCLUDED_WORKERS)
   def clientShuffleDynamicResourceEnabled: Boolean =
     get(CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED)
+  def clientShuffleDynamicResourceFactor: Double = 
get(CLIENT_SHUFFLE_DYNAMIC_RESOURCE_FACTOR)
   def appHeartbeatTimeoutMs: Long = get(APPLICATION_HEARTBEAT_TIMEOUT)
   def hdfsExpireDirsTimeoutMS: Long = get(HDFS_EXPIRE_DIRS_TIMEOUT)
   def dfsExpireDirsTimeoutMS: Long = get(DFS_EXPIRE_DIRS_TIMEOUT)
@@ -4845,6 +4846,17 @@ object CelebornConf extends Logging {
       .booleanConf
       .createWithDefault(false)
 
+  val CLIENT_SHUFFLE_DYNAMIC_RESOURCE_FACTOR: ConfigEntry[Double] =
+    buildConf("celeborn.client.shuffle.dynamicResourceFactor")
+      .categories("client")
+      .version("0.6.0")
+      .doc("The ChangePartitionManager will check whether (unavailable workers 
/ shuffle allocated workers) " +
+        "is more than the factor before obtaining candidate workers from the 
requestSlots RPC response" +
+        "when ${CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key} set true")
+      .doubleConf
+      .checkValue(v => v >= 0.0 && v <= 1.0, "Should be in [0.0, 1.0].")
+      .createWithDefault(0.5)
+
   val CLIENT_PUSH_STAGE_END_TIMEOUT: ConfigEntry[Long] =
     buildConf("celeborn.client.push.stageEnd.timeout")
       .withAlternative("celeborn.push.stageEnd.timeout")
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 57086c035..4a96184b6 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -413,7 +413,6 @@ object ControlMessages extends Logging {
       totalWritten: Long,
       fileCount: Long,
       needCheckedWorkerList: util.List[WorkerInfo],
-      needAvailableWorkers: Boolean,
       override var requestId: String = ZERO_UUID,
       shouldResponse: Boolean = false) extends MasterRequestMessage
 
@@ -422,7 +421,6 @@ object ControlMessages extends Logging {
       excludedWorkers: util.List[WorkerInfo],
       unknownWorkers: util.List[WorkerInfo],
       shuttingWorkers: util.List[WorkerInfo],
-      availableWorkers: util.List[WorkerInfo],
       registeredShuffles: util.List[Integer],
       checkQuotaResponse: CheckQuotaResponse) extends Message
 
@@ -812,7 +810,6 @@ object ControlMessages extends Logging {
           totalWritten,
           fileCount,
           needCheckedWorkerList,
-          needAvailableWorkers,
           requestId,
           shouldResponse) =>
       val payload = PbHeartbeatFromApplication.newBuilder()
@@ -822,7 +819,6 @@ object ControlMessages extends Logging {
         .setFileCount(fileCount)
         .addAllNeedCheckedWorkerList(needCheckedWorkerList.asScala.map(
           PbSerDeUtils.toPbWorkerInfo(_, true, true)).toList.asJava)
-        .setNeedAvailableWorkers(needAvailableWorkers)
         .setShouldResponse(shouldResponse)
         .build().toByteArray
       new TransportMessage(MessageType.HEARTBEAT_FROM_APPLICATION, payload)
@@ -832,7 +828,6 @@ object ControlMessages extends Logging {
           excludedWorkers,
           unknownWorkers,
           shuttingWorkers,
-          availableWorkers,
           registeredShuffles,
           checkQuotaResponse) =>
       val pbCheckQuotaResponse = 
PbCheckQuotaResponse.newBuilder().setAvailable(
@@ -845,8 +840,6 @@ object ControlMessages extends Logging {
           unknownWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, 
true)).toList.asJava)
         .addAllShuttingWorkers(
           shuttingWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, 
true)).toList.asJava)
-        .addAllAvailableWorkers(
-          availableWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, 
true)).toList.asJava)
         .addAllRegisteredShuffles(registeredShuffles)
         .setCheckQuotaResponse(pbCheckQuotaResponse)
         .build().toByteArray
@@ -1219,7 +1212,6 @@ object ControlMessages extends Logging {
           new util.ArrayList[WorkerInfo](
             pbHeartbeatFromApplication.getNeedCheckedWorkerListList.asScala
               .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava),
-          pbHeartbeatFromApplication.getNeedAvailableWorkers,
           pbHeartbeatFromApplication.getRequestId,
           pbHeartbeatFromApplication.getShouldResponse)
 
@@ -1235,8 +1227,6 @@ object ControlMessages extends Logging {
             .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
           pbHeartbeatFromApplicationResponse.getShuttingWorkersList.asScala
             .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
-          pbHeartbeatFromApplicationResponse.getAvailableWorkersList.asScala
-            .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
           pbHeartbeatFromApplicationResponse.getRegisteredShufflesList,
           CheckQuotaResponse(pbCheckQuotaResponse.getAvailable, 
pbCheckQuotaResponse.getReason))
 
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 7dad000b0..fd1160ff3 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -96,6 +96,7 @@ license: |
 | celeborn.client.shuffle.compression.zstd.level | 1 | false | Compression 
level for Zstd compression codec, its value should be an integer between -5 and 
22. Increasing the compression level will result in better compression at the 
expense of more CPU and memory. | 0.3.0 | 
celeborn.shuffle.compression.zstd.level | 
 | celeborn.client.shuffle.decompression.lz4.xxhash.instance | 
&lt;undefined&gt; | false | Decompression XXHash instance for Lz4. Available 
options: JNI, JAVASAFE, JAVAUNSAFE. | 0.3.2 |  | 
 | celeborn.client.shuffle.dynamicResourceEnabled | false | false | When 
enabled, the ChangePartitionManager will obtain candidate workers from the 
availableWorkers pool “ +during heartbeats when worker resource change. | 0.6.0 
|  | 
+| celeborn.client.shuffle.dynamicResourceFactor | 0.5 | false | The 
ChangePartitionManager will check whether (unavailable workers / shuffle 
allocated workers) is more than the factor before obtaining candidate workers 
from the requestSlots RPC response when 
${CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key} is set to true | 0.6.0 |  | 
 | celeborn.client.shuffle.expired.checkInterval | 60s | false | Interval for 
client to check expired shuffles. | 0.3.0 | 
celeborn.shuffle.expired.checkInterval | 
 | celeborn.client.shuffle.manager.port | 0 | false | Port used by the 
LifecycleManager on the Driver. | 0.3.0 | celeborn.shuffle.manager.port | 
 | celeborn.client.shuffle.mapPartition.split.enabled | false | false | whether 
to enable shuffle partition split. Currently, this only applies to 
MapPartition. | 0.3.1 |  | 
diff --git 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index dae240977..117b4875a 100644
--- 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++ 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -392,7 +392,6 @@ private[celeborn] class Master(
           totalWritten,
           fileCount,
           needCheckedWorkerList,
-          needAvailableWorkers,
           requestId,
           shouldResponse) =>
       logDebug(s"Received heartbeat from app $appId")
@@ -405,7 +404,6 @@ private[celeborn] class Master(
           totalWritten,
           fileCount,
           needCheckedWorkerList,
-          needAvailableWorkers,
           requestId,
           shouldResponse))
 
@@ -1097,7 +1095,6 @@ private[celeborn] class Master(
       totalWritten: Long,
       fileCount: Long,
       needCheckedWorkerList: util.List[WorkerInfo],
-      needAvailableWorkers: Boolean,
       requestId: String,
       shouldResponse: Boolean): Unit = {
     statusSystem.handleAppHeartbeat(
@@ -1111,13 +1108,7 @@ private[celeborn] class Master(
     if (shouldResponse) {
       // UserResourceConsumption and DiskInfo are eliminated from WorkerInfo
       // during serialization of HeartbeatFromApplicationResponse
-      var availableWorksSentToClient = new util.ArrayList[WorkerInfo]()
-      if (needAvailableWorkers) {
-        availableWorksSentToClient = new util.ArrayList[WorkerInfo](
-          statusSystem.workersMap.values().asScala.filter(worker =>
-            statusSystem.isWorkerAvailable(worker)).toList.asJava)
-      }
-      val appRelatedShuffles =
+      var appRelatedShuffles =
         statusSystem.registeredAppAndShuffles.getOrDefault(appId, 
Collections.emptySet())
       context.reply(HeartbeatFromApplicationResponse(
         StatusCode.SUCCESS,
@@ -1126,7 +1117,6 @@ private[celeborn] class Master(
         unknownWorkers,
         new util.ArrayList[WorkerInfo](
           (statusSystem.shutdownWorkers.asScala ++ 
statusSystem.decommissionWorkers.asScala).asJava),
-        availableWorksSentToClient,
         new util.ArrayList(appRelatedShuffles),
         CheckQuotaResponse(isAvailable = true, "")))
     } else {
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
index bee3b61ec..3c6547c1c 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
@@ -19,14 +19,14 @@ package org.apache.celeborn.tests.client
 
 import java.util
 
-import scala.collection.JavaConverters.mapAsScalaMapConverter
+import scala.collection.JavaConverters.{asScalaSetConverter, 
mapAsScalaMapConverter}
 
 import org.apache.celeborn.client.{ChangePartitionManager, 
ChangePartitionRequest, LifecycleManager, WithShuffleClientSuite}
 import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, 
WorkerInfo}
 import org.apache.celeborn.common.protocol.message.StatusCode
-import org.apache.celeborn.common.util.{CelebornExitKind, JavaUtils}
+import org.apache.celeborn.common.util.JavaUtils
 import org.apache.celeborn.service.deploy.MiniClusterFeature
 
 class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite
@@ -37,6 +37,9 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
+  }
+
+  override def beforeEach(): Unit = {
     val testConf = Map(
       s"${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}" -> "3")
     val (master, _) = setupMiniClusterWithRandomPorts(testConf, testConf, 
workerNum = 1)
@@ -45,12 +48,13 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
       master.conf.get(CelebornConf.MASTER_ENDPOINTS.key))
   }
 
-  test("test changePartition with available workers") {
+  test("test changePartition with available workers expansion") {
     val shuffleId = nextShuffleId
     val conf = celebornConf.clone
     conf.set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "3")
-      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key, "true")
       .set(CelebornConf.CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED.key, 
"false")
+      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key, "true")
+      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_FACTOR.key, "0.0")
 
     val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
     val changePartitionManager: ChangePartitionManager =
@@ -74,8 +78,6 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
       res.workerResource,
       updateEpoch = false)
 
-    val slots = res.workerResource
-    val candidatesWorkers = new util.HashSet(slots.keySet())
     if (reserveSlotsSuccess) {
       val allocatedWorkers =
         JavaUtils.newConcurrentHashMap[WorkerInfo, 
ShufflePartitionLocationInfo]()
@@ -85,21 +87,15 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
           partitionLocationInfo.addPrimaryPartitions(primaryLocations)
           partitionLocationInfo.addReplicaPartitions(replicaLocations)
           allocatedWorkers.put(workerInfo, partitionLocationInfo)
+          lifecycleManager.updateLatestPartitionLocations(shuffleId, 
primaryLocations)
       }
       lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
-      
lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(candidatesWorkers)
     }
     assert(lifecycleManager.workerSnapshots(shuffleId).size() == 1)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() 
== 1)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
 == 0)
 
     // total workerNum is 1 + 2 = 3 now
     setUpWorkers(workerConfForAdding, 2)
-    // longer than APPLICATION_HEARTBEAT_INTERVAL 10s
-    Thread.sleep(15000)
     assert(workerInfos.size == 3)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() 
== 1)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
 == 2)
 
     0 until 10 foreach { partitionId: Int =>
       val req = ChangePartitionRequest(
@@ -117,28 +113,185 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
         Array(req),
         lifecycleManager.commitManager.isSegmentGranularityVisible(shuffleId))
     }
-    assert(
-      
lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size() + 
lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() == 3)
-
+    Thread.sleep(5000)
     assert(lifecycleManager.workerSnapshots(shuffleId).size() > 1)
-    assert(
-      lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() 
> 1)
+
+    lifecycleManager.stop()
+  }
+
+  test("test changePartition with available workers shrink") {
+    setUpWorkers(workerConfForAdding, 1)
+    // total workers 1 + 1 = 2
+    assert(workerInfos.size == 2)
+    val shuffleId = nextShuffleId
+    val conf = celebornConf.clone
+    conf.set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "3")
+      .set(CelebornConf.CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED.key, 
"false")
+      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key, "true")
+      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_FACTOR.key, "0.5")
+
+    val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+    val changePartitionManager: ChangePartitionManager =
+      new ChangePartitionManager(conf, lifecycleManager)
+    val ids = new util.ArrayList[Integer](10)
+    0 until 10 foreach {
+      ids.add(_)
+    }
+    val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 
ids)
+    assert(res.status == StatusCode.SUCCESS)
+    val workerNum = res.workerResource.keySet().size()
+    assert(workerNum == 2)
+
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet(),
+      shuffleId,
+      new ShuffleFailedWorkers())
+
+    val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry(
+      shuffleId,
+      new util.HashSet(res.workerResource.keySet()),
+      res.workerResource,
+      updateEpoch = false)
+
+    if (reserveSlotsSuccess) {
+      val allocatedWorkers =
+        JavaUtils.newConcurrentHashMap[WorkerInfo, 
ShufflePartitionLocationInfo]()
+      res.workerResource.asScala.foreach {
+        case (workerInfo, (primaryLocations, replicaLocations)) =>
+          val partitionLocationInfo = new ShufflePartitionLocationInfo()
+          partitionLocationInfo.addPrimaryPartitions(primaryLocations)
+          partitionLocationInfo.addReplicaPartitions(replicaLocations)
+          allocatedWorkers.put(workerInfo, partitionLocationInfo)
+          lifecycleManager.updateLatestPartitionLocations(shuffleId, 
primaryLocations)
+      }
+      lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
+    }
+    assert(lifecycleManager.workerSnapshots(shuffleId).size() == workerNum)
 
     // shut down workers test
+    // total workerNum is 2 - 1 = 1 now
     val workerInfoList = workerInfos.toList
-    0 until 2 foreach { index =>
+    0 until 1 foreach { index =>
       val (worker, _) = workerInfoList(index)
-      worker.stop(CelebornExitKind.EXIT_IMMEDIATELY)
-      worker.rpcEnv.shutdown()
       // Workers in miniClusterFeature wont update status with master through 
heartbeat.
       // So update status manually.
       masterInfo._1.statusSystem.excludedWorkers.add(worker.workerInfo)
-      workerInfos.remove(worker)
+
+      val failedWorker = new ShuffleFailedWorkers()
+      failedWorker.put(
+        worker.workerInfo,
+        (StatusCode.RESERVE_SLOTS_FAILED, System.currentTimeMillis()))
+      lifecycleManager.workerStatusTracker.recordWorkerFailure(failedWorker)
     }
 
-    Thread.sleep(15000)
-    assert(
-      lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() 
+ lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size() 
== 1)
+    val tmpSnapshotWorkers =
+      lifecycleManager
+        .workerSnapshots(shuffleId)
+        .keySet()
+        .asScala
+        .filter(lifecycleManager.workerStatusTracker.workerAvailable)
+    assert(tmpSnapshotWorkers.size == 1)
+
+    // add another new worker
+    setUpWorkers(workerConfForAdding, 1)
+    assert(workerInfos.size == 3)
+
+    0 until 10 foreach { partitionId: Int =>
+      val req = ChangePartitionRequest(
+        null,
+        shuffleId,
+        partitionId,
+        -1,
+        null,
+        None)
+      changePartitionManager.changePartitionRequests.computeIfAbsent(
+        shuffleId,
+        changePartitionManager.rpcContextRegisterFunc)
+      changePartitionManager.handleRequestPartitions(
+        shuffleId,
+        Array(req),
+        lifecycleManager.commitManager.isSegmentGranularityVisible(shuffleId))
+    }
+
+    val snapshotCandidates =
+      lifecycleManager
+        .workerSnapshots(shuffleId)
+        .keySet()
+        .asScala
+        .filter(lifecycleManager.workerStatusTracker.workerAvailable)
+
+    assert(snapshotCandidates.size == 2)
+    lifecycleManager.stop()
+  }
+
+  test("test changePartition with available workers and factor") {
+    val shuffleId = nextShuffleId
+    val conf = celebornConf.clone
+    conf.set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "3")
+      .set(CelebornConf.CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED.key, 
"false")
+      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key, "true")
+      .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_FACTOR.key, "1.0")
+
+    val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+    val changePartitionManager: ChangePartitionManager =
+      new ChangePartitionManager(conf, lifecycleManager)
+    val ids = new util.ArrayList[Integer](10)
+    0 until 10 foreach {
+      ids.add(_)
+    }
+    val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 
ids)
+    assert(res.status == StatusCode.SUCCESS)
+
+    // workerNum is 1
+    val workerNum = res.workerResource.keySet().size()
+    assert(workerNum == 1)
+
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet(),
+      shuffleId,
+      new ShuffleFailedWorkers())
+
+    val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry(
+      shuffleId,
+      new util.HashSet(res.workerResource.keySet()),
+      res.workerResource,
+      updateEpoch = false)
+
+    if (reserveSlotsSuccess) {
+      val allocatedWorkers =
+        JavaUtils.newConcurrentHashMap[WorkerInfo, 
ShufflePartitionLocationInfo]()
+      res.workerResource.asScala.foreach {
+        case (workerInfo, (primaryLocations, replicaLocations)) =>
+          val partitionLocationInfo = new ShufflePartitionLocationInfo()
+          partitionLocationInfo.addPrimaryPartitions(primaryLocations)
+          partitionLocationInfo.addReplicaPartitions(replicaLocations)
+          allocatedWorkers.put(workerInfo, partitionLocationInfo)
+      }
+      lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
+    }
+    assert(lifecycleManager.workerSnapshots(shuffleId).size() == workerNum)
+
+    // total workerNum is 1 + 2 = 3 now
+    setUpWorkers(workerConfForAdding, 2)
+    assert(workerInfos.size == 3)
+
+    0 until 10 foreach { partitionId: Int =>
+      val req = ChangePartitionRequest(
+        null,
+        shuffleId,
+        partitionId,
+        -1,
+        null,
+        None)
+      changePartitionManager.changePartitionRequests.computeIfAbsent(
+        shuffleId,
+        changePartitionManager.rpcContextRegisterFunc)
+      changePartitionManager.handleRequestPartitions(
+        shuffleId,
+        Array(req),
+        lifecycleManager.commitManager.isSegmentGranularityVisible(shuffleId))
+    }
+    assert(lifecycleManager.workerSnapshots(shuffleId).size() == workerNum)
 
     lifecycleManager.stop()
   }
@@ -160,7 +313,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
     val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 
ids)
     assert(res.status == StatusCode.SUCCESS)
 
-    // workerNum is 1 (after 1 add 2 and stop 2)
+    // workerNum is 1
     val workerNum = res.workerResource.keySet().size()
     assert(workerNum == 1)
 
@@ -175,8 +328,6 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
       res.workerResource,
       updateEpoch = false)
 
-    val slots = res.workerResource
-    val candidatesWorkers = new util.HashSet(slots.keySet())
     if (reserveSlotsSuccess) {
       val allocatedWorkers =
         JavaUtils.newConcurrentHashMap[WorkerInfo, 
ShufflePartitionLocationInfo]()
@@ -188,18 +339,13 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
           allocatedWorkers.put(workerInfo, partitionLocationInfo)
       }
       lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
-      
lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(candidatesWorkers)
     }
     assert(lifecycleManager.workerSnapshots(shuffleId).size() == workerNum)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() 
== workerNum)
 
     // total workerNum is 1 + 2 = 3 now
     setUpWorkers(workerConfForAdding, 2)
     // longer than APPLICATION_HEARTBEAT_INTERVAL 10s
-    Thread.sleep(15000)
     assert(workerInfos.size == 3)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() 
== workerNum)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
 == 0)
 
     0 until 10 foreach { partitionId: Int =>
       val req = ChangePartitionRequest(
@@ -217,16 +363,13 @@ class ChangePartitionManagerUpdateWorkersSuite extends 
WithShuffleClientSuite
         Array(req),
         lifecycleManager.commitManager.isSegmentGranularityVisible(shuffleId))
     }
-    logInfo(s"reallocated worker num: ${res.workerResource.keySet().size()}; 
workerInfo: ${res.workerResource.keySet()}")
     assert(lifecycleManager.workerSnapshots(shuffleId).size() == workerNum)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() 
== workerNum)
-    
assert(lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
 == 0)
 
     lifecycleManager.stop()
   }
 
-  override def afterAll(): Unit = {
-    logInfo("all test complete , stop celeborn mini cluster")
+  override def afterEach(): Unit = {
+    logInfo("test complete, stop celeborn mini cluster")
     shutdownMiniCluster()
   }
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
index 4e3b7e6ff..85de0bf20 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -57,7 +57,7 @@ class RetryReviveTest extends AnyFunSuite
   }
 
   test(
-    "celeborn spark integration test - e2e test retry revive with available 
workers from heartbeat") {
+    "celeborn spark integration test - e2e test retry revive with new 
allocated workers from RPC") {
     val testConf = Map(
       s"${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}" -> "3",
       s"${CelebornConf.MASTER_SLOT_ASSIGN_EXTRA_SLOTS.key}" -> "0")
@@ -66,8 +66,8 @@ class RetryReviveTest extends AnyFunSuite
     val sparkConf = new SparkConf()
       .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "true")
       .set(s"spark.${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}", "3")
-      .set(s"spark.${CelebornConf.CLIENT_SLOT_ASSIGN_MAX_WORKERS.key}", "1")
       
.set(s"spark.${CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key}", 
"true")
+      
.set(s"spark.${CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_FACTOR.key}", "0")
       .set(s"spark.${CelebornConf.MASTER_SLOT_ASSIGN_EXTRA_SLOTS.key}", "0")
       .setAppName("celeborn-demo").setMaster("local[2]")
     val ss = SparkSession.builder()

Reply via email to