This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 7685fa7db [CELEBORN-1636] Client supports dynamic update of Worker
resources on the server
7685fa7db is described below
commit 7685fa7db22a156d42f8824192ccd6264d351de7
Author: szt <[email protected]>
AuthorDate: Mon Oct 28 09:49:31 2024 +0800
[CELEBORN-1636] Client supports dynamic update of Worker resources on the
server
### What changes were proposed in this pull request?
Currently, the ChangePartitionManager retrieves workers from the
LifeCycleManager's workerSnapshot. However, during the revival process in
reallocateChangePartitionRequestSlotsFromCandidates, it does not account for
newly added available workers resulting from elastic contraction and expansion.
This PR addresses this issue by updating the candidate workers in the
ChangePartitionManager to use the available workers reported in the heartbeat
from the master.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
Closes #2835 from zaynt4606/clbdev.
Authored-by: szt <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../celeborn/client/ApplicationHeartbeater.scala | 3 +
.../celeborn/client/ChangePartitionManager.scala | 126 +++++++----
.../apache/celeborn/client/LifecycleManager.scala | 8 +-
.../celeborn/client/WorkerStatusTracker.scala | 61 +++++-
.../celeborn/client/WorkerStatusTrackerSuite.scala | 121 ++++++++++-
common/src/main/proto/TransportMessages.proto | 2 +
.../org/apache/celeborn/common/CelebornConf.scala | 11 +
.../common/protocol/message/ControlMessages.scala | 10 +
docs/configuration/client.md | 1 +
.../celeborn/service/deploy/master/Master.scala | 10 +
.../ChangePartitionManagerUpdateWorkersSuite.scala | 232 +++++++++++++++++++++
.../client/LifecycleManagerCommitFilesSuite.scala | 10 +-
.../client/LifecycleManagerDestroySlotsSuite.scala | 15 +-
.../LifecycleManagerSetupEndpointSuite.scala | 4 +-
.../celeborn/tests/spark/RetryReviveTest.scala | 30 ++-
.../service/deploy/MiniClusterFeature.scala | 22 +-
16 files changed, 589 insertions(+), 77 deletions(-)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
b/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
index 0b850ff52..4b90a8bca 100644
---
a/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/ApplicationHeartbeater.scala
@@ -47,6 +47,7 @@ class ApplicationHeartbeater(
// Use independent app heartbeat threads to avoid being blocked by other
operations.
private val appHeartbeatIntervalMs = conf.appHeartbeatIntervalMs
private val applicationUnregisterEnabled = conf.applicationUnregisterEnabled
+ private val clientShuffleDynamicResourceEnabled =
conf.clientShuffleDynamicResourceEnabled
private val appHeartbeatHandlerThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor(
"celeborn-client-lifecycle-manager-app-heartbeater")
@@ -69,6 +70,7 @@ class ApplicationHeartbeater(
tmpTotalWritten,
tmpTotalFileCount,
workerStatusTracker.getNeedCheckedWorkers().toList.asJava,
+ clientShuffleDynamicResourceEnabled,
ZERO_UUID,
true)
val response = requestHeartbeat(appHeartbeat)
@@ -129,6 +131,7 @@ class ApplicationHeartbeater(
List.empty.asJava,
List.empty.asJava,
List.empty.asJava,
+ List.empty.asJava,
List.empty.asJava)
}
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index 46628aaab..771b51151 100644
---
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -18,14 +18,15 @@
package org.apache.celeborn.client
import java.util
-import java.util.{Set => JSet}
+import java.util.{function, Set => JSet}
import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService,
ScheduledFuture, TimeUnit}
import scala.collection.JavaConverters._
+import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.WorkerInfo
+import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo,
WorkerInfo}
import org.apache.celeborn.common.protocol.PartitionLocation
import
org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
import org.apache.celeborn.common.protocol.message.StatusCode
@@ -45,7 +46,8 @@ class ChangePartitionManager(
private val pushReplicateEnabled = conf.clientPushReplicateEnabled
// shuffleId -> (partitionId -> set of ChangePartition)
- private val changePartitionRequests =
+ val changePartitionRequests
+ : ConcurrentHashMap[Int, ConcurrentHashMap[Integer,
JSet[ChangePartitionRequest]]] =
JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer,
JSet[ChangePartitionRequest]]]()
// shuffleId -> locks
@@ -74,6 +76,8 @@ class ChangePartitionManager(
private val testRetryRevive = conf.testRetryRevive
+ private val clientShuffleDynamicResourceEnabled =
conf.clientShuffleDynamicResourceEnabled
+
def start(): Unit = {
batchHandleChangePartition = batchHandleChangePartitionSchedulerThread.map
{
// noinspection ConvertExpressionToSAM
@@ -128,7 +132,8 @@ class ChangePartitionManager(
batchHandleChangePartitionSchedulerThread.foreach(ThreadUtils.shutdown(_))
}
- private val rpcContextRegisterFunc =
+ val rpcContextRegisterFunc
+ : function.Function[Int, ConcurrentHashMap[Integer,
JSet[ChangePartitionRequest]]] =
new util.function.Function[
Int,
ConcurrentHashMap[Integer, util.Set[ChangePartitionRequest]]]() {
@@ -148,6 +153,13 @@ class ChangePartitionManager(
}
}
+ private val updateWorkerSnapshotsFunc =
+ new util.function.Function[WorkerInfo, ShufflePartitionLocationInfo] {
+ override def apply(w: WorkerInfo): ShufflePartitionLocationInfo = {
+ new ShufflePartitionLocationInfo()
+ }
+ }
+
def handleRequestPartitionLocation(
context: RequestLocationCallContext,
shuffleId: Int,
@@ -186,7 +198,7 @@ class ChangePartitionManager(
partitionId,
StatusCode.SUCCESS,
Some(latestLoc),
- lifecycleManager.workerStatusTracker.workerAvailable(oldPartition))
+
lifecycleManager.workerStatusTracker.workerAvailableByLocation(oldPartition))
logDebug(s"[handleRequestPartitionLocation]: For shuffle:
$shuffleId," +
s" old partition: $partitionId-$oldEpoch, new partition:
$latestLoc found, return it")
return
@@ -254,7 +266,7 @@ class ChangePartitionManager(
req.partitionId,
StatusCode.SUCCESS,
Option(newLocation),
-
lifecycleManager.workerStatusTracker.workerAvailable(req.oldPartition))))
+
lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition))))
}
}
@@ -274,18 +286,49 @@ class ChangePartitionManager(
req.partitionId,
status,
None,
-
lifecycleManager.workerStatusTracker.workerAvailable(req.oldPartition))))
+
lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition))))
}
}
- // Get candidate worker that not in excluded worker list of shuffleId
- val candidates =
- lifecycleManager
- .workerSnapshots(shuffleId)
- .keySet()
+ val candidates = new util.HashSet[WorkerInfo]()
+ if (clientShuffleDynamicResourceEnabled) {
+ // availableWorkers won't filter excludedWorkers in heartbeat, so we have to
do filtering.
+ candidates.addAll(lifecycleManager
+ .workerStatusTracker
+ .availableWorkersWithEndpoint
+ .values()
.asScala
+ .toSet
.filter(lifecycleManager.workerStatusTracker.workerAvailable)
- .toList
+ .asJava)
+
+ // Set up endpoints for those availableWorkers without an endpoint
+ val workersRequireEndpoints = new util.HashSet[WorkerInfo](
+
lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.asScala.filter(
+ lifecycleManager.workerStatusTracker.workerAvailable).asJava)
+ val connectFailedWorkers = new ShuffleFailedWorkers()
+ lifecycleManager.setupEndpoints(
+ workersRequireEndpoints,
+ shuffleId,
+ connectFailedWorkers)
+
workersRequireEndpoints.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
+ candidates.addAll(workersRequireEndpoints)
+
+ // Update worker status
+
lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(workersRequireEndpoints)
+
lifecycleManager.workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
+
lifecycleManager.workerStatusTracker.removeFromExcludedWorkers(candidates)
+ } else {
+ val snapshotCandidates =
+ lifecycleManager
+ .workerSnapshots(shuffleId)
+ .keySet()
+ .asScala
+ .filter(lifecycleManager.workerStatusTracker.workerAvailable)
+ .asJava
+ candidates.addAll(snapshotCandidates)
+ }
+
if (candidates.size < 1 || (pushReplicateEnabled && candidates.size < 2)) {
logError("[Update partition] failed for not enough candidates for
revive.")
replyFailure(StatusCode.SLOT_NOT_AVAILABLE)
@@ -294,11 +337,13 @@ class ChangePartitionManager(
// PartitionSplit all contains oldPartition
val newlyAllocatedLocations =
-
reallocateChangePartitionRequestSlotsFromCandidates(changePartitions.toList,
candidates)
+ reallocateChangePartitionRequestSlotsFromCandidates(
+ changePartitions.toList,
+ candidates.asScala.toList)
if (!lifecycleManager.reserveSlotsWithRetry(
shuffleId,
- new util.HashSet(candidates.toSet.asJava),
+ candidates,
newlyAllocatedLocations,
isSegmentGranularityVisible = isSegmentGranularityVisible)) {
logError(s"[Update partition] failed for $shuffleId.")
@@ -306,33 +351,32 @@ class ChangePartitionManager(
return
}
- val newPrimaryLocations =
- newlyAllocatedLocations.asScala.flatMap {
- case (workInfo, (primaryLocations, replicaLocations)) =>
- // Add all re-allocated slots to worker snapshots.
- lifecycleManager.workerSnapshots(shuffleId).asScala
- .get(workInfo)
- .foreach { partitionLocationInfo =>
- partitionLocationInfo.addPrimaryPartitions(primaryLocations)
- lifecycleManager.updateLatestPartitionLocations(shuffleId,
primaryLocations)
- partitionLocationInfo.addReplicaPartitions(replicaLocations)
- }
- // partition location can be null when call reserveSlotsWithRetry().
- val locations = (primaryLocations.asScala ++
replicaLocations.asScala.map(_.getPeer))
- .distinct.filter(_ != null)
- if (locations.nonEmpty) {
- val changes = locations.map { partition =>
- s"(partition ${partition.getId} epoch from ${partition.getEpoch
- 1} to ${partition.getEpoch})"
- }.mkString("[", ", ", "]")
- logInfo(s"[Update partition] success for " +
- s"shuffle $shuffleId, succeed partitions: " +
- s"$changes.")
- }
+ val newPrimaryLocations = newlyAllocatedLocations.asScala.flatMap {
+ case (workInfo, (primaryLocations, replicaLocations)) =>
+ // Add all re-allocated slots to worker snapshots.
+ val partitionLocationInfo =
lifecycleManager.workerSnapshots(shuffleId).computeIfAbsent(
+ workInfo,
+ updateWorkerSnapshotsFunc)
+ partitionLocationInfo.addPrimaryPartitions(primaryLocations)
+ partitionLocationInfo.addReplicaPartitions(replicaLocations)
+ lifecycleManager.updateLatestPartitionLocations(shuffleId,
primaryLocations)
+
+ // partition location can be null when call reserveSlotsWithRetry().
+ val locations = (primaryLocations.asScala ++
replicaLocations.asScala.map(_.getPeer))
+ .distinct.filter(_ != null)
+ if (locations.nonEmpty) {
+ val changes = locations.map { partition =>
+ s"(partition ${partition.getId} epoch from ${partition.getEpoch -
1} to ${partition.getEpoch})"
+ }.mkString("[", ", ", "]")
+ logInfo(s"[Update partition] success for " +
+ s"shuffle $shuffleId, succeed partitions: " +
+ s"$changes.")
+ }
- // TODO: should record the new partition locations and acknowledge
the new partitionLocations to downstream task,
- // in scenario the downstream task start early before the upstream
task.
- locations
- }
+ // TODO: should record the new partition locations and acknowledge the
new partitionLocations to downstream task,
+ // in scenario the downstream task start early before the upstream
task.
+ locations
+ }
replySuccess(newPrimaryLocations.toArray)
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index d35802958..38b9f28ba 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -449,11 +449,11 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
}
def setupEndpoints(
- slots: WorkerResource,
+ workers: util.Set[WorkerInfo],
shuffleId: Int,
connectFailedWorkers: ShuffleFailedWorkers): Unit = {
val futures = new util.LinkedList[(Future[RpcEndpointRef], WorkerInfo)]()
- slots.asScala foreach { case (workerInfo, _) =>
+ workers.asScala foreach { workerInfo =>
val future =
workerRpcEnvInUse.asyncSetupEndpointRefByAddr(RpcEndpointAddress(
RpcAddress.apply(workerInfo.host, workerInfo.rpcPort),
WORKER_EP))
@@ -676,8 +676,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
val connectFailedWorkers = new ShuffleFailedWorkers()
// Second, for each worker, try to initialize the endpoint.
- setupEndpoints(slots, shuffleId, connectFailedWorkers)
-
+ setupEndpoints(slots.keySet(), shuffleId, connectFailedWorkers)
candidatesWorkers.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
// If newly allocated from primary and can setup endpoint success,
LifecycleManager should remove worker from
@@ -713,6 +712,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
allocatedWorkers.put(workerInfo, partitionLocationInfo)
}
shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
+ workerStatusTracker.addWorkersWithEndpoint(candidatesWorkers)
registeredShuffle.add(shuffleId)
commitManager.registerShuffle(
shuffleId,
diff --git
a/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
b/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
index d7ccf0fe8..088341813 100644
--- a/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/WorkerStatusTracker.scala
@@ -30,17 +30,27 @@ import org.apache.celeborn.common.meta.WorkerInfo
import org.apache.celeborn.common.protocol.PartitionLocation
import
org.apache.celeborn.common.protocol.message.ControlMessages.HeartbeatFromApplicationResponse
import org.apache.celeborn.common.protocol.message.StatusCode
-import org.apache.celeborn.common.util.Utils
+import org.apache.celeborn.common.util.{JavaUtils, Utils}
class WorkerStatusTracker(
conf: CelebornConf,
lifecycleManager: LifecycleManager) extends Logging {
private val excludedWorkerExpireTimeout =
conf.clientExcludedWorkerExpireTimeout
private val workerStatusListeners =
ConcurrentHashMap.newKeySet[WorkerStatusListener]()
+ private val clientShuffleDynamicResourceEnabled =
conf.clientShuffleDynamicResourceEnabled
val excludedWorkers = new ShuffleFailedWorkers()
val shuttingWorkers: JSet[WorkerInfo] = new JHashSet[WorkerInfo]()
+ // Workers that have already set an endpoint can skip the setupEndpoint
process in changePartition when reviving
+ // key: WorkerInfo.toUniqueId value: WorkerInfo
+ val availableWorkersWithEndpoint: ConcurrentHashMap[String, WorkerInfo] =
+ JavaUtils.newConcurrentHashMap[String, WorkerInfo]()
+
+ // Workers that may be available but have not been used (without endpoint)
+ // availableWorkersWithoutEndpoint is empty until
appHeartbeatWithAvailableWorkers is set to true
+ val availableWorkersWithoutEndpoint =
ConcurrentHashMap.newKeySet[WorkerInfo]()
+
def registerWorkerStatusListener(workerStatusListener:
WorkerStatusListener): Unit = {
workerStatusListeners.add(workerStatusListener)
}
@@ -61,7 +71,7 @@ class WorkerStatusTracker(
!excludedWorkers.containsKey(worker) && !shuttingWorkers.contains(worker)
}
- def workerAvailable(loc: PartitionLocation): Boolean = {
+ def workerAvailableByLocation(loc: PartitionLocation): Boolean = {
if (loc == null) {
false
} else {
@@ -131,13 +141,16 @@ class WorkerStatusTracker(
failedWorkers.asScala.foreach {
case (worker, (StatusCode.WORKER_SHUTDOWN, _)) =>
shuttingWorkers.add(worker)
+ removeFromAvailableWorkers(worker)
case (worker, (statusCode, registerTime)) if
!excludedWorkers.containsKey(worker) =>
excludedWorkers.put(worker, (statusCode, registerTime))
+ removeFromAvailableWorkers(worker)
case (worker, (statusCode, _))
if statusCode == StatusCode.NO_AVAILABLE_WORKING_DIR ||
statusCode == StatusCode.RESERVE_SLOTS_FAILED ||
statusCode == StatusCode.WORKER_UNKNOWN =>
excludedWorkers.put(worker, (statusCode,
excludedWorkers.get(worker)._2))
+ removeFromAvailableWorkers(worker)
case _ => // Not cover
}
}
@@ -147,10 +160,22 @@ class WorkerStatusTracker(
excludedWorkers.keySet.removeAll(workers)
}
+ private def removeFromAvailableWorkers(worker: WorkerInfo): Unit = {
+ availableWorkersWithEndpoint.remove(worker.toUniqueId())
+ availableWorkersWithoutEndpoint.remove(worker)
+ }
+
+ def addWorkersWithEndpoint(workers: JHashSet[WorkerInfo]): Unit = {
+ availableWorkersWithoutEndpoint.removeAll(workers)
+ workers.asScala.foreach { workerInfo =>
+ availableWorkersWithEndpoint.put(workerInfo.toUniqueId(), workerInfo)
+ }
+ }
+
def handleHeartbeatResponse(res: HeartbeatFromApplicationResponse): Unit = {
if (res.statusCode == StatusCode.SUCCESS) {
logDebug(s"Received Worker status from Primary, excluded workers:
${res.excludedWorkers} " +
- s"unknown workers: ${res.unknownWorkers}, shutdown workers:
${res.shuttingWorkers}")
+ s"unknown workers: ${res.unknownWorkers}, shutdown workers:
${res.shuttingWorkers}, available workers from heartbeat:
${res.availableWorkers}")
val current = System.currentTimeMillis()
var statusChanged = false
@@ -188,9 +213,33 @@ class WorkerStatusTracker(
statusChanged = true
}
}
- val retainResult = shuttingWorkers.retainAll(res.shuttingWorkers)
- val addResult = shuttingWorkers.addAll(res.shuttingWorkers)
- statusChanged = statusChanged || retainResult || addResult
+
+ val retainShuttingWorkersResult =
shuttingWorkers.retainAll(res.shuttingWorkers)
+ val addShuttingWorkersResult =
shuttingWorkers.addAll(res.shuttingWorkers)
+
+ if (clientShuffleDynamicResourceEnabled) {
+ // AvailableWorkers filter Client excludedWorkers and shuttingWorkers.
+ // AvailableWorkers already filtered res.excludedWorkers and
res.shuttingWorkers.
+ val resAvailableWorkers: JSet[WorkerInfo] = new
JHashSet[WorkerInfo](res.availableWorkers)
+ // update availableWorkers
+ // availableWorkers won't filter excludedWorkers.
+ // So before using them we have to filter excludedWorkers.
+ availableWorkersWithoutEndpoint.retainAll(resAvailableWorkers)
+ availableWorkersWithEndpoint.keySet().retainAll(
+ resAvailableWorkers.asScala.map(_.toUniqueId()).asJava)
+ resAvailableWorkers.asScala.foreach { workerInfo: WorkerInfo =>
+ if
(!availableWorkersWithEndpoint.keySet.contains(workerInfo.toUniqueId())) {
+ availableWorkersWithoutEndpoint.add(workerInfo)
+ } else {
+ if (availableWorkersWithoutEndpoint.contains(workerInfo)) {
+ availableWorkersWithoutEndpoint.remove(workerInfo)
+ }
+ }
+ }
+ }
+
+ statusChanged =
+ statusChanged || retainShuttingWorkersResult ||
addShuttingWorkersResult
// Always trigger commit files for shutting down workers from
HeartbeatFromApplicationResponse
// See details in CELEBORN-696
if (!res.unknownWorkers.isEmpty || !res.shuttingWorkers.isEmpty) {
diff --git
a/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
b/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
index d2c21245d..27196e8d9 100644
---
a/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
+++
b/client/src/test/scala/org/apache/celeborn/client/WorkerStatusTrackerSuite.scala
@@ -23,16 +23,16 @@ import org.junit.Assert
import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CelebornConf
-import
org.apache.celeborn.common.CelebornConf.CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT
+import
org.apache.celeborn.common.CelebornConf.{CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT,
CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED}
import org.apache.celeborn.common.meta.WorkerInfo
import
org.apache.celeborn.common.protocol.message.ControlMessages.HeartbeatFromApplicationResponse
import org.apache.celeborn.common.protocol.message.StatusCode
class WorkerStatusTrackerSuite extends CelebornFunSuite {
-
- test("handleHeartbeatResponse") {
+ test("handleHeartbeatResponse without availableWorkers") {
val celebornConf = new CelebornConf()
celebornConf.set(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT, 2000L)
+ celebornConf.set(CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED, false)
val statusTracker = new WorkerStatusTracker(celebornConf, null)
val registerTime = System.currentTimeMillis()
@@ -40,7 +40,7 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
statusTracker.excludedWorkers.put(mock("host2"),
(StatusCode.WORKER_SHUTDOWN, registerTime))
// test reserve (only statusCode list in handleHeartbeatResponse)
- val empty = buildResponse(Array.empty, Array.empty, Array.empty)
+ val empty = buildResponse(Array.empty, Array.empty, Array.empty,
Array.empty)
statusTracker.handleHeartbeatResponse(empty)
// only reserve host1
@@ -50,7 +50,8 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host2")))
// add shutdown/excluded worker
- val response1 = buildResponse(Array("host0"), Array("host1", "host3"),
Array("host4"))
+ val response1 =
+ buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"),
Array.empty)
statusTracker.handleHeartbeatResponse(response1)
// test keep Unknown register time
@@ -58,15 +59,15 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
statusTracker.excludedWorkers.get(mock("host1")),
(StatusCode.WORKER_UNKNOWN, registerTime))
- // test new added workers
+ // test new added shutdown/excluded workers
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host0")))
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host3")))
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
// test re heartbeat with shutdown workers
- val response3 = buildResponse(Array.empty, Array.empty, Array("host4"))
- statusTracker.handleHeartbeatResponse(response3)
+ val response2 = buildResponse(Array.empty, Array.empty, Array("host4"),
Array.empty)
+ statusTracker.handleHeartbeatResponse(response2)
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
@@ -78,24 +79,124 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
// test register time elapsed
Thread.sleep(3000)
- val response2 = buildResponse(Array.empty, Array("host5", "host6"),
Array.empty)
+ val response3 = buildResponse(Array.empty, Array("host5", "host6"),
Array.empty, Array.empty)
+ statusTracker.handleHeartbeatResponse(response3)
+ Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
+
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host1")))
+
+ // test available workers
+ Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(),
0)
+ val response4 = buildResponse(
+ Array.empty,
+ Array.empty,
+ Array.empty,
+ Array("host5", "host6", "host7", "host8"))
+ statusTracker.handleHeartbeatResponse(response4)
+
+ // availableWorkers won't update through heartbeat
+ // when CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED is set to false
+ Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(),
0)
+ // available workers won't overwrite excluded workers
+ Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
+ Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host5")))
+ Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host6")))
+ }
+
+ test("handleHeartbeatResponse with availableWorkers") {
+ val celebornConf = new CelebornConf()
+ celebornConf.set(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT, 2000L)
+ celebornConf.set(CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED, true)
+ val statusTracker = new WorkerStatusTracker(celebornConf, null)
+
+ val registerTime = System.currentTimeMillis()
+ statusTracker.excludedWorkers.put(mock("host1"),
(StatusCode.WORKER_UNKNOWN, registerTime))
+ statusTracker.excludedWorkers.put(mock("host2"),
(StatusCode.WORKER_SHUTDOWN, registerTime))
+
+ // test reserve (only statusCode list in handleHeartbeatResponse)
+ val empty = buildResponse(Array.empty, Array.empty, Array.empty,
Array.empty)
+ statusTracker.handleHeartbeatResponse(empty)
+
+ // only reserve host1
+ Assert.assertEquals(
+ statusTracker.excludedWorkers.get(mock("host1")),
+ (StatusCode.WORKER_UNKNOWN, registerTime))
+
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host2")))
+
+ // add shutdown/excluded worker
+ val response1 =
+ buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"),
Array.empty)
+ statusTracker.handleHeartbeatResponse(response1)
+
+ // test keep Unknown register time
+ Assert.assertEquals(
+ statusTracker.excludedWorkers.get(mock("host1")),
+ (StatusCode.WORKER_UNKNOWN, registerTime))
+ // test new added workers
+ Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host0")))
+ Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host3")))
+
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
+ Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
+
+ // test re heartbeat with shutdown workers
+ val response2 = buildResponse(Array.empty, Array.empty, Array("host4"),
Array.empty)
statusTracker.handleHeartbeatResponse(response2)
+
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
+ Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))
+
+ // test remove
+ val workers = new util.HashSet[WorkerInfo]
+ workers.add(mock("host3"))
+ statusTracker.removeFromExcludedWorkers(workers)
+
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host3")))
+
+ // test register time elapsed
+ Thread.sleep(3000)
+ val response3 = buildResponse(Array.empty, Array("host5", "host6"),
Array.empty, Array.empty)
+ statusTracker.handleHeartbeatResponse(response3)
Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host1")))
+
+ // test available workers
+ Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(),
0)
+ val response4 = buildResponse(
+ Array.empty,
+ Array.empty,
+ Array.empty,
+ Array("host5", "host6", "host7", "host8"))
+ statusTracker.handleHeartbeatResponse(response4)
+
+ // availableWorkers won't be filtered by excludedWorkers
+ // So before using them we have to filter excludedWorkers
+ Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(),
4)
+ // available workers won't overwrite excluded workers
+ Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
+ Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host5")))
+ Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host6")))
+
+ // test re heartbeat with available workers
+ val response5 = buildResponse(Array.empty, Array.empty, Array.empty,
Array("host8", "host9"))
+ statusTracker.handleHeartbeatResponse(response5)
+ Assert.assertEquals(statusTracker.availableWorkersWithoutEndpoint.size(),
2)
+
Assert.assertFalse(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host7")))
+
Assert.assertTrue(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host8")))
+
Assert.assertTrue(statusTracker.availableWorkersWithoutEndpoint.contains(mock("host9")))
}
private def buildResponse(
excludedWorkerHosts: Array[String],
unknownWorkerHosts: Array[String],
- shuttingWorkerHosts: Array[String]): HeartbeatFromApplicationResponse = {
+ shuttingWorkerHosts: Array[String],
+ availableWorkerHosts: Array[String]): HeartbeatFromApplicationResponse =
{
val excludedWorkers = mockWorkers(excludedWorkerHosts)
val unknownWorkers = mockWorkers(unknownWorkerHosts)
val shuttingWorkers = mockWorkers(shuttingWorkerHosts)
+ val availableWorkers = mockWorkers(availableWorkerHosts)
HeartbeatFromApplicationResponse(
StatusCode.SUCCESS,
excludedWorkers,
unknownWorkers,
shuttingWorkers,
+ availableWorkers,
new util.ArrayList[Integer]())
}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index bc8da08ab..d5b129b0d 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -442,6 +442,7 @@ message PbHeartbeatFromApplication {
string requestId = 4;
repeated PbWorkerInfo needCheckedWorkerList = 5;
bool shouldResponse = 6;
+ bool needAvailableWorkers = 7;
}
message PbHeartbeatFromApplicationResponse {
@@ -450,6 +451,7 @@ message PbHeartbeatFromApplicationResponse {
repeated PbWorkerInfo unknownWorkers = 3;
repeated PbWorkerInfo shuttingWorkers = 4;
repeated int32 registeredShuffles = 5;
+ repeated PbWorkerInfo availableWorkers = 6;
}
message PbCheckQuota {
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 9693146c3..97befab96 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -898,6 +898,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
def clientReserveSlotsRetryWait: Long = get(CLIENT_RESERVE_SLOTS_RETRY_WAIT)
def clientRequestCommitFilesMaxRetries: Int =
get(CLIENT_COMMIT_FILE_REQUEST_MAX_RETRY)
def clientCommitFilesIgnoreExcludedWorkers: Boolean =
get(CLIENT_COMMIT_IGNORE_EXCLUDED_WORKERS)
+ def clientShuffleDynamicResourceEnabled: Boolean =
+ get(CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED)
def appHeartbeatTimeoutMs: Long = get(APPLICATION_HEARTBEAT_TIMEOUT)
def hdfsExpireDirsTimeoutMS: Long = get(HDFS_EXPIRE_DIRS_TIMEOUT)
def dfsExpireDirsTimeoutMS: Long = get(DFS_EXPIRE_DIRS_TIMEOUT)
@@ -4827,6 +4829,15 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(false)
+ val CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED: ConfigEntry[Boolean] =
+ buildConf("celeborn.client.shuffle.dynamicResourceEnabled")
+ .categories("client")
+ .version("0.6.0")
+ .doc("When enabled, the ChangePartitionManager will obtain candidate
workers from the availableWorkers pool " +
+ "during heartbeats when worker resources change.")
+ .booleanConf
+ .createWithDefault(false)
+
val CLIENT_PUSH_STAGE_END_TIMEOUT: ConfigEntry[Long] =
buildConf("celeborn.client.push.stageEnd.timeout")
.withAlternative("celeborn.push.stageEnd.timeout")
diff --git
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index d0eb85b02..e5ea22fd4 100644
---
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -413,6 +413,7 @@ object ControlMessages extends Logging {
totalWritten: Long,
fileCount: Long,
needCheckedWorkerList: util.List[WorkerInfo],
+ needAvailableWorkers: Boolean,
override var requestId: String = ZERO_UUID,
shouldResponse: Boolean = false) extends MasterRequestMessage
@@ -421,6 +422,7 @@ object ControlMessages extends Logging {
excludedWorkers: util.List[WorkerInfo],
unknownWorkers: util.List[WorkerInfo],
shuttingWorkers: util.List[WorkerInfo],
+ availableWorkers: util.List[WorkerInfo],
registeredShuffles: util.List[Integer]) extends Message
case class CheckQuota(userIdentifier: UserIdentifier) extends Message
@@ -809,6 +811,7 @@ object ControlMessages extends Logging {
totalWritten,
fileCount,
needCheckedWorkerList,
+ needAvailableWorkers,
requestId,
shouldResponse) =>
val payload = PbHeartbeatFromApplication.newBuilder()
@@ -818,6 +821,7 @@ object ControlMessages extends Logging {
.setFileCount(fileCount)
.addAllNeedCheckedWorkerList(needCheckedWorkerList.asScala.map(
PbSerDeUtils.toPbWorkerInfo(_, true, true)).toList.asJava)
+ .setNeedAvailableWorkers(needAvailableWorkers)
.setShouldResponse(shouldResponse)
.build().toByteArray
new TransportMessage(MessageType.HEARTBEAT_FROM_APPLICATION, payload)
@@ -827,6 +831,7 @@ object ControlMessages extends Logging {
excludedWorkers,
unknownWorkers,
shuttingWorkers,
+ availableWorkers,
registeredShuffles) =>
val payload = PbHeartbeatFromApplicationResponse.newBuilder()
.setStatus(statusCode.getValue)
@@ -836,6 +841,8 @@ object ControlMessages extends Logging {
unknownWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true,
true)).toList.asJava)
.addAllShuttingWorkers(
shuttingWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true,
true)).toList.asJava)
+ .addAllAvailableWorkers(
+ availableWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true,
true)).toList.asJava)
.addAllRegisteredShuffles(registeredShuffles)
.build().toByteArray
new TransportMessage(MessageType.HEARTBEAT_FROM_APPLICATION_RESPONSE,
payload)
@@ -1207,6 +1214,7 @@ object ControlMessages extends Logging {
new util.ArrayList[WorkerInfo](
pbHeartbeatFromApplication.getNeedCheckedWorkerListList.asScala
.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava),
+ pbHeartbeatFromApplication.getNeedAvailableWorkers,
pbHeartbeatFromApplication.getRequestId,
pbHeartbeatFromApplication.getShouldResponse)
@@ -1221,6 +1229,8 @@ object ControlMessages extends Logging {
.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
pbHeartbeatFromApplicationResponse.getShuttingWorkersList.asScala
.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
+ pbHeartbeatFromApplicationResponse.getAvailableWorkersList.asScala
+ .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
pbHeartbeatFromApplicationResponse.getRegisteredShufflesList)
case CHECK_QUOTA_VALUE =>
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index ce309a13f..492bd8133 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -95,6 +95,7 @@ license: |
| celeborn.client.shuffle.compression.codec | LZ4 | false | The codec used to
compress shuffle data. By default, Celeborn provides three codecs: `lz4`,
`zstd`, `none`. `none` means that shuffle compression is disabled. Since Flink
version 1.17, zstd is supported for Flink shuffle client. | 0.3.0 |
celeborn.shuffle.compression.codec,remote-shuffle.job.compression.codec |
| celeborn.client.shuffle.compression.zstd.level | 1 | false | Compression
level for Zstd compression codec, its value should be an integer between -5 and
22. Increasing the compression level will result in better compression at the
expense of more CPU and memory. | 0.3.0 |
celeborn.shuffle.compression.zstd.level |
| celeborn.client.shuffle.decompression.lz4.xxhash.instance |
<undefined> | false | Decompression XXHash instance for Lz4. Available
options: JNI, JAVASAFE, JAVAUNSAFE. | 0.3.2 | |
+| celeborn.client.shuffle.dynamicResourceEnabled | false | false | When
enabled, the ChangePartitionManager will obtain candidate workers from the
availableWorkers pool during heartbeats when worker resources change. | 0.6.0
| |
| celeborn.client.shuffle.expired.checkInterval | 60s | false | Interval for
client to check expired shuffles. | 0.3.0 |
celeborn.shuffle.expired.checkInterval |
| celeborn.client.shuffle.manager.port | 0 | false | Port used by the
LifecycleManager on the Driver. | 0.3.0 | celeborn.shuffle.manager.port |
| celeborn.client.shuffle.mapPartition.split.enabled | false | false | whether
to enable shuffle partition split. Currently, this only applies to
MapPartition. | 0.3.1 | |
diff --git
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index 86baaa921..bec47890e 100644
---
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -407,6 +407,7 @@ private[celeborn] class Master(
totalWritten,
fileCount,
needCheckedWorkerList,
+ needAvailableWorkers,
requestId,
shouldResponse) =>
logDebug(s"Received heartbeat from app $appId")
@@ -419,6 +420,7 @@ private[celeborn] class Master(
totalWritten,
fileCount,
needCheckedWorkerList,
+ needAvailableWorkers,
requestId,
shouldResponse))
@@ -1113,6 +1115,7 @@ private[celeborn] class Master(
totalWritten: Long,
fileCount: Long,
needCheckedWorkerList: util.List[WorkerInfo],
+ needAvailableWorkers: Boolean,
requestId: String,
shouldResponse: Boolean): Unit = {
statusSystem.handleAppHeartbeat(
@@ -1126,6 +1129,12 @@ private[celeborn] class Master(
if (shouldResponse) {
// UserResourceConsumption and DiskInfo are eliminated from WorkerInfo
// during serialization of HeartbeatFromApplicationResponse
+ var availableWorkersSentToClient = new util.ArrayList[WorkerInfo]()
+ if (needAvailableWorkers) {
+ availableWorkersSentToClient = new util.ArrayList[WorkerInfo](
+ statusSystem.workers.asScala.filter(worker =>
+ statusSystem.isWorkerAvailable(worker)).asJava)
+ }
var appRelatedShuffles =
statusSystem.registeredAppAndShuffles.getOrDefault(appId,
Collections.emptySet())
context.reply(HeartbeatFromApplicationResponse(
@@ -1135,6 +1144,7 @@ private[celeborn] class Master(
needCheckedWorkerList,
new util.ArrayList[WorkerInfo](
(statusSystem.shutdownWorkers.asScala ++
statusSystem.decommissionWorkers.asScala).asJava),
+ availableWorkersSentToClient,
new util.ArrayList(appRelatedShuffles)))
} else {
context.reply(OneWayMessageResponse)
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
new file mode 100644
index 000000000..bee3b61ec
--- /dev/null
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.client
+
+import java.util
+
+import scala.collection.JavaConverters.mapAsScalaMapConverter
+
+import org.apache.celeborn.client.{ChangePartitionManager,
ChangePartitionRequest, LifecycleManager, WithShuffleClientSuite}
+import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo,
WorkerInfo}
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.{CelebornExitKind, JavaUtils}
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite
+ with MiniClusterFeature {
+ celebornConf
+ .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false")
+ .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val testConf = Map(
+ s"${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}" -> "3")
+ val (master, _) = setupMiniClusterWithRandomPorts(testConf, testConf,
workerNum = 1)
+ celebornConf.set(
+ CelebornConf.MASTER_ENDPOINTS.key,
+ master.conf.get(CelebornConf.MASTER_ENDPOINTS.key))
+ }
+
+ test("test changePartition with available workers") {
+ val shuffleId = nextShuffleId
+ val conf = celebornConf.clone
+ conf.set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "3")
+ .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key, "true")
+ .set(CelebornConf.CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED.key,
"false")
+
+ val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+ val changePartitionManager: ChangePartitionManager =
+ new ChangePartitionManager(conf, lifecycleManager)
+ val ids = new util.ArrayList[Integer](10)
+ 0 until 10 foreach {
+ ids.add(_)
+ }
+ val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId,
ids)
+ assert(res.status == StatusCode.SUCCESS)
+ assert(res.workerResource.keySet().size() == 1)
+
+ lifecycleManager.setupEndpoints(
+ res.workerResource.keySet(),
+ shuffleId,
+ new ShuffleFailedWorkers())
+
+ val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry(
+ shuffleId,
+ new util.HashSet(res.workerResource.keySet()),
+ res.workerResource,
+ updateEpoch = false)
+
+ val slots = res.workerResource
+ val candidatesWorkers = new util.HashSet(slots.keySet())
+ if (reserveSlotsSuccess) {
+ val allocatedWorkers =
+ JavaUtils.newConcurrentHashMap[WorkerInfo,
ShufflePartitionLocationInfo]()
+ res.workerResource.asScala.foreach {
+ case (workerInfo, (primaryLocations, replicaLocations)) =>
+ val partitionLocationInfo = new ShufflePartitionLocationInfo()
+ partitionLocationInfo.addPrimaryPartitions(primaryLocations)
+ partitionLocationInfo.addReplicaPartitions(replicaLocations)
+ allocatedWorkers.put(workerInfo, partitionLocationInfo)
+ }
+ lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
+
lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(candidatesWorkers)
+ }
+ assert(lifecycleManager.workerSnapshots(shuffleId).size() == 1)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size()
== 1)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
== 0)
+
+ // total workerNum is 1 + 2 = 3 now
+ setUpWorkers(workerConfForAdding, 2)
+ // longer than APPLICATION_HEARTBEAT_INTERVAL 10s
+ Thread.sleep(15000)
+ assert(workerInfos.size == 3)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size()
== 1)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
== 2)
+
+ 0 until 10 foreach { partitionId: Int =>
+ val req = ChangePartitionRequest(
+ null,
+ shuffleId,
+ partitionId,
+ -1,
+ null,
+ None)
+ changePartitionManager.changePartitionRequests.computeIfAbsent(
+ shuffleId,
+ changePartitionManager.rpcContextRegisterFunc)
+ changePartitionManager.handleRequestPartitions(
+ shuffleId,
+ Array(req),
+ lifecycleManager.commitManager.isSegmentGranularityVisible(shuffleId))
+ }
+ assert(
+
lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size() +
lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size() == 3)
+
+ assert(lifecycleManager.workerSnapshots(shuffleId).size() > 1)
+ assert(
+ lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size()
> 1)
+
+ // shut down workers test
+ val workerInfoList = workerInfos.toList
+ 0 until 2 foreach { index =>
+ val (worker, _) = workerInfoList(index)
+ worker.stop(CelebornExitKind.EXIT_IMMEDIATELY)
+ worker.rpcEnv.shutdown()
+ // Workers in miniClusterFeature won't update status with master through
heartbeat.
+ // So update status manually.
+ masterInfo._1.statusSystem.excludedWorkers.add(worker.workerInfo)
+ workerInfos.remove(worker)
+ }
+
+ Thread.sleep(15000)
+ assert(
+ lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size()
+ lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
== 1)
+
+ lifecycleManager.stop()
+ }
+
+ test("test changePartition without available workers") {
+ val shuffleId = nextShuffleId
+ val conf = celebornConf.clone
+ conf.set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "3")
+ .set(CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key, "false")
+ .set(CelebornConf.CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED.key,
"false")
+
+ val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+ val changePartitionManager: ChangePartitionManager =
+ new ChangePartitionManager(conf, lifecycleManager)
+ val ids = new util.ArrayList[Integer](10)
+ 0 until 10 foreach {
+ ids.add(_)
+ }
+ val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId,
ids)
+ assert(res.status == StatusCode.SUCCESS)
+
+ // workerNum is 1 (after 1 add 2 and stop 2)
+ val workerNum = res.workerResource.keySet().size()
+ assert(workerNum == 1)
+
+ lifecycleManager.setupEndpoints(
+ res.workerResource.keySet(),
+ shuffleId,
+ new ShuffleFailedWorkers())
+
+ val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry(
+ shuffleId,
+ new util.HashSet(res.workerResource.keySet()),
+ res.workerResource,
+ updateEpoch = false)
+
+ val slots = res.workerResource
+ val candidatesWorkers = new util.HashSet(slots.keySet())
+ if (reserveSlotsSuccess) {
+ val allocatedWorkers =
+ JavaUtils.newConcurrentHashMap[WorkerInfo,
ShufflePartitionLocationInfo]()
+ res.workerResource.asScala.foreach {
+ case (workerInfo, (primaryLocations, replicaLocations)) =>
+ val partitionLocationInfo = new ShufflePartitionLocationInfo()
+ partitionLocationInfo.addPrimaryPartitions(primaryLocations)
+ partitionLocationInfo.addReplicaPartitions(replicaLocations)
+ allocatedWorkers.put(workerInfo, partitionLocationInfo)
+ }
+ lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
+
lifecycleManager.workerStatusTracker.addWorkersWithEndpoint(candidatesWorkers)
+ }
+ assert(lifecycleManager.workerSnapshots(shuffleId).size() == workerNum)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size()
== workerNum)
+
+ // total workerNum is 1 + 2 = 3 now
+ setUpWorkers(workerConfForAdding, 2)
+ // longer than APPLICATION_HEARTBEAT_INTERVAL 10s
+ Thread.sleep(15000)
+ assert(workerInfos.size == 3)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size()
== workerNum)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
== 0)
+
+ 0 until 10 foreach { partitionId: Int =>
+ val req = ChangePartitionRequest(
+ null,
+ shuffleId,
+ partitionId,
+ -1,
+ null,
+ None)
+ changePartitionManager.changePartitionRequests.computeIfAbsent(
+ shuffleId,
+ changePartitionManager.rpcContextRegisterFunc)
+ changePartitionManager.handleRequestPartitions(
+ shuffleId,
+ Array(req),
+ lifecycleManager.commitManager.isSegmentGranularityVisible(shuffleId))
+ }
+ logInfo(s"reallocated worker num: ${res.workerResource.keySet().size()};
workerInfo: ${res.workerResource.keySet()}")
+ assert(lifecycleManager.workerSnapshots(shuffleId).size() == workerNum)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithEndpoint.size()
== workerNum)
+
assert(lifecycleManager.workerStatusTracker.availableWorkersWithoutEndpoint.size()
== 0)
+
+ lifecycleManager.stop()
+ }
+
+ override def afterAll(): Unit = {
+ logInfo("all tests complete, stop celeborn mini cluster")
+ shutdownMiniCluster()
+ }
+}
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
index 90192e283..1f7769140 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
@@ -58,7 +58,10 @@ class LifecycleManagerCommitFilesSuite extends
WithShuffleClientSuite with MiniC
assert(res.status == StatusCode.SUCCESS)
assert(res.workerResource.keySet().size() == 3)
- lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new
ShuffleFailedWorkers())
+ lifecycleManager.setupEndpoints(
+ res.workerResource.keySet(),
+ shuffleId,
+ new ShuffleFailedWorkers())
lifecycleManager.reserveSlotsWithRetry(
shuffleId,
@@ -108,7 +111,10 @@ class LifecycleManagerCommitFilesSuite extends
WithShuffleClientSuite with MiniC
assert(res.status == StatusCode.SUCCESS)
assert(res.workerResource.keySet().size() == 3)
- lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new
ShuffleFailedWorkers())
+ lifecycleManager.setupEndpoints(
+ res.workerResource.keySet(),
+ shuffleId,
+ new ShuffleFailedWorkers())
lifecycleManager.reserveSlotsWithRetry(
shuffleId,
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala
index 1fc639d60..e83268202 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala
@@ -56,7 +56,10 @@ class LifecycleManagerDestroySlotsSuite extends
WithShuffleClientSuite with Mini
assert(res.status == StatusCode.SUCCESS)
assert(res.workerResource.keySet().size() == 3)
- lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new
ShuffleFailedWorkers())
+ lifecycleManager.setupEndpoints(
+ res.workerResource.keySet(),
+ shuffleId,
+ new ShuffleFailedWorkers())
lifecycleManager.reserveSlotsWithRetry(
shuffleId,
@@ -95,7 +98,10 @@ class LifecycleManagerDestroySlotsSuite extends
WithShuffleClientSuite with Mini
assert(res.status == StatusCode.SUCCESS)
assert(res.workerResource.keySet().size() == 3)
- lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new
ShuffleFailedWorkers())
+ lifecycleManager.setupEndpoints(
+ res.workerResource.keySet(),
+ shuffleId,
+ new ShuffleFailedWorkers())
lifecycleManager.reserveSlotsWithRetry(
shuffleId,
@@ -134,7 +140,10 @@ class LifecycleManagerDestroySlotsSuite extends
WithShuffleClientSuite with Mini
assert(res.status == StatusCode.SUCCESS)
assert(res.workerResource.keySet().size() == 3)
- lifecycleManager.setupEndpoints(res.workerResource, shuffleId, new
ShuffleFailedWorkers())
+ lifecycleManager.setupEndpoints(
+ res.workerResource.keySet(),
+ shuffleId,
+ new ShuffleFailedWorkers())
lifecycleManager.reserveSlotsWithRetry(
shuffleId,
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala
index 8fee9cbb9..9ad455756 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala
@@ -52,7 +52,7 @@ class LifecycleManagerSetupEndpointSuite extends
WithShuffleClientSuite with Min
assert(res.workerResource.keySet().size() == 3)
val connectFailedWorkers = new ShuffleFailedWorkers()
- lifecycleManager.setupEndpoints(res.workerResource, 0,
connectFailedWorkers)
+ lifecycleManager.setupEndpoints(res.workerResource.keySet(), 0,
connectFailedWorkers)
assert(connectFailedWorkers.isEmpty)
lifecycleManager.stop()
@@ -73,7 +73,7 @@ class LifecycleManagerSetupEndpointSuite extends
WithShuffleClientSuite with Min
firstWorker.rpcEnv.shutdown()
val connectFailedWorkers = new ShuffleFailedWorkers()
- lifecycleManager.setupEndpoints(res.workerResource, 0,
connectFailedWorkers)
+ lifecycleManager.setupEndpoints(res.workerResource.keySet(), 0,
connectFailedWorkers)
assert(connectFailedWorkers.size() == 1)
assert(connectFailedWorkers.keySet().asScala.head ==
firstWorker.workerInfo)
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
index ea6eed601..4e3b7e6ff 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -32,21 +32,43 @@ class RetryReviveTest extends AnyFunSuite
override def beforeAll(): Unit = {
logInfo("test initialized , setup celeborn mini cluster")
- setupMiniClusterWithRandomPorts()
}
- override def beforeEach(): Unit = {
- ShuffleClient.reset()
- }
+ override def beforeEach(): Unit = {}
override def afterEach(): Unit = {
System.gc()
}
test("celeborn spark integration test - retry revive as configured times") {
+ setupMiniClusterWithRandomPorts()
+ ShuffleClient.reset()
+ val sparkConf = new SparkConf()
+ .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "true")
+ .set(s"spark.${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}", "3")
+ .setAppName("celeborn-demo").setMaster("local[2]")
+ val ss = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .getOrCreate()
+ val result = ss.sparkContext.parallelize(1 to 1000, 2)
+ .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(4).collect()
+ assert(result.size == 1000)
+ ss.stop()
+ }
+
+ test(
+ "celeborn spark integration test - e2e test retry revive with available
workers from heartbeat") {
+ val testConf = Map(
+ s"${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}" -> "3",
+ s"${CelebornConf.MASTER_SLOT_ASSIGN_EXTRA_SLOTS.key}" -> "0")
+ setupMiniClusterWithRandomPorts(testConf)
+ ShuffleClient.reset()
val sparkConf = new SparkConf()
.set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "true")
.set(s"spark.${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}", "3")
+ .set(s"spark.${CelebornConf.CLIENT_SLOT_ASSIGN_MAX_WORKERS.key}", "1")
+
.set(s"spark.${CelebornConf.CLIENT_SHUFFLE_DYNAMIC_RESOURCE_ENABLED.key}",
"true")
+ .set(s"spark.${CelebornConf.MASTER_SLOT_ASSIGN_EXTRA_SLOTS.key}", "0")
.setAppName("celeborn-demo").setMaster("local[2]")
val ss = SparkSession.builder()
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))
diff --git
a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
index fb6bd67f6..7c1280f13 100644
---
a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
+++
b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
@@ -37,6 +37,7 @@ trait MiniClusterFeature extends Logging {
var masterInfo: (Master, Thread) = _
val workerInfos = new mutable.HashMap[Worker, Thread]()
+ var workerConfForAdding: Map[String, String] = _
class RunnerWrap[T](code: => T) extends Thread {
@@ -71,6 +72,7 @@ trait MiniClusterFeature extends Logging {
workerConf
logInfo(
s"generated configuration. Master conf = $finalMasterConf, worker
conf = $finalWorkerConf")
+ workerConfForAdding = finalWorkerConf
val (m, w) =
setUpMiniCluster(masterConf = finalMasterConf, workerConf =
finalWorkerConf, workerNum)
master = m
@@ -148,10 +150,7 @@ trait MiniClusterFeature extends Logging {
}
}
- private def setUpMiniCluster(
- masterConf: Map[String, String] = null,
- workerConf: Map[String, String] = null,
- workerNum: Int = 3): (Master, collection.Set[Worker]) = {
+ def setUpMaster(masterConf: Map[String, String] = null): Master = {
val timeout = 30000
val master = createMaster(masterConf)
val masterStartedSignal = Array(false)
@@ -176,7 +175,13 @@ trait MiniClusterFeature extends Logging {
throw new BindException("cannot start master rpc endpoint")
}
}
+ master
+ }
+ def setUpWorkers(
+ workerConf: Map[String, String] = null,
+ workerNum: Int = 3): collection.Set[Worker] = {
+ val timeout = 30000
val workers = new Array[Worker](workerNum)
val flagUpdateLock = new ReentrantLock()
val threads = (1 to workerNum).map { i =>
@@ -239,7 +244,14 @@ trait MiniClusterFeature extends Logging {
}
}
}
- (master, workerInfos.keySet)
+ workerInfos.keySet
+ }
+
+ private def setUpMiniCluster(
+ masterConf: Map[String, String] = null,
+ workerConf: Map[String, String] = null,
+ workerNum: Int = 3): (Master, collection.Set[Worker]) = {
+ (setUpMaster(masterConf), setUpWorkers(workerConf, workerNum))
}
def shutdownMiniCluster(): Unit = {