This is an automated email from the ASF dual-hosted git repository.

angerszhuuuu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e4dec96 [CELEBORN-21][REFACTOR] Extract revive-related logic from LifecycleManager (#1024)
1e4dec96 is described below

commit 1e4dec96b98dd3e70f38c922de68617002bad0ce
Author: Angerszhuuuu <[email protected]>
AuthorDate: Mon Dec 5 17:05:17 2022 +0800

    [CELEBORN-21][REFACTOR] Extract revive-related logic from LifecycleManager (#1024)
    
    * [CELEBORN-21][REFACTOR] Extract revive-related logic from LifecycleManager
---
 .../celeborn/client/ChangePartitionManager.scala   | 331 +++++++++++++++++++++
 .../apache/celeborn/client/LifecycleManager.scala  | 321 ++------------------
 2 files changed, 361 insertions(+), 291 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
new file mode 100644
index 00000000..ea1de6af
--- /dev/null
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client
+
+import java.util
+import java.util.{Set => JSet}
+import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledFuture, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.duration.DurationInt
+import scala.util.control.NonFatal
+
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.meta.WorkerInfo
+import org.apache.celeborn.common.protocol.PartitionLocation
+import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.{ThreadUtils, Utils}
+
+/**
+ * A pending request for a new location of one partition, registered by
+ * `ChangePartitionManager.handleRequestPartitionLocation`.
+ *
+ * @param context      the caller to reply to with the new location or a failure status
+ * @param epoch        epoch of `oldPartition` as reported by the caller
+ * @param oldPartition the location being replaced
+ * @param causes       failure that triggered the request, if any; when defined, the manager
+ *                     calls `LifecycleManager.blacklistPartition` for `oldPartition` before
+ *                     reallocating slots
+ */
+case class ChangePartitionRequest(
+    context: RequestLocationCallContext,
+    applicationId: String,
+    shuffleId: Int,
+    partitionId: Int,
+    epoch: Int,
+    oldPartition: PartitionLocation,
+    causes: Option[StatusCode])
+
+/**
+ * Handles requests for new partition locations on behalf of [[LifecycleManager]]: partition
+ * splits, revive requests, and new-location requests made while registering a shuffle.
+ *
+ * Concurrent requests for the same (shuffle, partition) are coalesced: only the first one
+ * triggers slot reallocation and every registered caller is answered once it completes. With
+ * `batchHandleChangePartitionEnabled`, pending requests are drained every
+ * `batchHandleChangePartitionRequestInterval` ms by a scheduler thread and handled on a worker
+ * pool; otherwise the first request for a partition is handled inline by its caller.
+ */
+class ChangePartitionManager(
+    conf: CelebornConf,
+    lifecycleManager: LifecycleManager) extends Logging {
+
+  private val pushReplicateEnabled = conf.pushReplicateEnabled
+  // shuffleId -> (partitionId -> set of ChangePartition)
+  // Each per-shuffle map is also the lock guarding its value sets and the shuffle's entry in
+  // `inBatchPartitions`.
+  private val changePartitionRequests =
+    new ConcurrentHashMap[Int, ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]]()
+  // shuffleId -> set of partition id
+  // Partitions currently taken by a batch task (batch mode only); plain HashSets, so they are
+  // only touched while holding the matching per-shuffle map above.
+  private val inBatchPartitions = new ConcurrentHashMap[Int, JSet[Integer]]()
+
+  private val batchHandleChangePartitionEnabled = conf.batchHandleChangePartitionEnabled
+  private val batchHandleChangePartitionExecutors = ThreadUtils.newDaemonCachedThreadPool(
+    "rss-lifecycle-manager-change-partition-executor",
+    conf.batchHandleChangePartitionNumThreads)
+  private val batchHandleChangePartitionRequestInterval =
+    conf.batchHandleChangePartitionRequestInterval
+  private val batchHandleChangePartitionSchedulerThread: Option[ScheduledExecutorService] =
+    if (batchHandleChangePartitionEnabled) {
+      Some(ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+        "rss-lifecycle-manager-change-partition-scheduler"))
+    } else {
+      None
+    }
+
+  // None until start() schedules the batch task, so stop() is safe before start().
+  private var batchHandleChangePartition: Option[ScheduledFuture[_]] = None
+
+  /** Schedules the periodic batch handler when batch handling is enabled; otherwise a no-op. */
+  def start(): Unit = {
+    batchHandleChangePartition = batchHandleChangePartitionSchedulerThread.map {
+      // noinspection ConvertExpressionToSAM
+      _.scheduleAtFixedRate(
+        new Runnable {
+          override def run(): Unit = {
+            try {
+              changePartitionRequests.asScala.foreach { case (shuffleId, requests) =>
+                // An empty map would only yield a no-op task.
+                if (!requests.isEmpty) {
+                  batchHandleChangePartitionExecutors.submit {
+                    new Runnable {
+                      override def run(): Unit = handlePendingRequests(shuffleId, requests)
+                    }
+                  }
+                }
+              }
+            } catch {
+              case e: InterruptedException =>
+                logError("Partition split scheduler thread is shutting down, detail: ", e)
+                throw e
+              case NonFatal(e) =>
+                // An exception escaping run() suppresses every later execution of a
+                // fixed-rate task, so log it and keep the schedule alive.
+                logError("Failed to schedule batch change partition handling, detail: ", e)
+            }
+          }
+        },
+        0,
+        batchHandleChangePartitionRequestInterval,
+        TimeUnit.MILLISECONDS)
+    }
+  }
+
+  /** Cancels the batch handler and shuts down every thread pool owned by this manager. */
+  def stop(): Unit = {
+    batchHandleChangePartition.foreach(_.cancel(true))
+    batchHandleChangePartitionSchedulerThread.foreach(ThreadUtils.shutdown(_, 800.millis))
+    // The worker pool is created unconditionally, so it must always be released.
+    batchHandleChangePartitionExecutors.shutdownNow()
+  }
+
+  /**
+   * Worker-pool task for one shuffle: takes the highest-epoch request of every pending
+   * partition that is not already in a batch, marks those partitions as in-batch, and
+   * handles them together.
+   */
+  private def handlePendingRequests(
+      shuffleId: Int,
+      requests: ConcurrentHashMap[Integer, JSet[ChangePartitionRequest]]): Unit = {
+    try {
+      // Select under the shuffle's lock: handleRequestPartitionLocation adds to the value sets
+      // and replySuccess/replyFailure clear in-batch marks while holding this same lock.
+      val distinctPartitions = requests.synchronized {
+        Option(inBatchPartitions.get(shuffleId)) match {
+          case Some(inBatch) =>
+            // For each partition only need handle one request
+            requests.asScala.collect {
+              case (partitionId, pending) if !inBatch.contains(partitionId) =>
+                inBatch.add(partitionId)
+                pending.asScala.maxBy(_.epoch)
+            }.toArray
+          case None =>
+            // The shuffle was expired (removeExpiredShuffle) after this task was submitted.
+            Array.empty[ChangePartitionRequest]
+        }
+      }
+      if (distinctPartitions.nonEmpty) {
+        batchHandleRequestPartitions(
+          distinctPartitions.head.applicationId,
+          shuffleId,
+          distinctPartitions)
+      }
+    } catch {
+      case NonFatal(e) =>
+        // Otherwise the exception would be captured in the discarded Future and never seen.
+        logError(s"Failed to batch handle change partition for shuffle $shuffleId, detail: ", e)
+    }
+  }
+
+  private val rpcContextRegisterFunc =
+    new util.function.Function[
+      Int,
+      ConcurrentHashMap[Integer, util.Set[ChangePartitionRequest]]]() {
+      override def apply(s: Int): ConcurrentHashMap[Integer, util.Set[ChangePartitionRequest]] =
+        new ConcurrentHashMap()
+    }
+
+  private val inBatchShuffleIdRegisterFunc = new util.function.Function[Int, util.Set[Integer]]() {
+    override def apply(s: Int): util.Set[Integer] = new util.HashSet[Integer]()
+  }
+
+  /**
+   * Registers a request for a new location of `partitionId`.
+   *
+   * If a request for the partition is already pending, `context` is only added to it. If a
+   * location newer than `oldEpoch` is already known, `context` is answered with it at once.
+   * Otherwise the request becomes the pending one; when batch handling is disabled it is
+   * handled before this method returns.
+   *
+   * @param context  replied to with the new location or a failure status
+   * @param oldEpoch epoch of `oldPartition` as seen by the caller
+   * @param cause    failure that triggered the request, if any
+   */
+  def handleRequestPartitionLocation(
+      context: RequestLocationCallContext,
+      applicationId: String,
+      shuffleId: Int,
+      partitionId: Int,
+      oldEpoch: Int,
+      oldPartition: PartitionLocation,
+      cause: Option[StatusCode] = None): Unit = {
+
+    val changePartition = ChangePartitionRequest(
+      context,
+      applicationId,
+      shuffleId,
+      partitionId,
+      oldEpoch,
+      oldPartition,
+      cause)
+    // check if there exists request for the partition, if do just register
+    val requests = changePartitionRequests.computeIfAbsent(shuffleId, rpcContextRegisterFunc)
+    inBatchPartitions.computeIfAbsent(shuffleId, inBatchShuffleIdRegisterFunc)
+
+    lifecycleManager.registerCommitPartition(applicationId, shuffleId, oldPartition, cause)
+
+    // True only when this call registered the first pending request for the partition and is
+    // therefore responsible for getting it handled.
+    val isFirstRequest = requests.synchronized {
+      Option(requests.get(partitionId)) match {
+        case Some(pending) =>
+          pending.add(changePartition)
+          logTrace(s"[handleRequestPartitionLocation] For $shuffleId, request for same " +
+            s"partition $partitionId-$oldEpoch exists, register context.")
+          false
+        case None =>
+          // If new slot for the partition has been allocated, reply and return.
+          // Else register and allocate for it.
+          getLatestPartition(shuffleId, partitionId, oldEpoch) match {
+            case Some(latestLoc) =>
+              context.reply(StatusCode.SUCCESS, Some(latestLoc))
+              logDebug(s"New partition found, old partition $partitionId-$oldEpoch return it." +
+                s" shuffleId: $shuffleId $latestLoc")
+              false
+            case None =>
+              val set = new util.HashSet[ChangePartitionRequest]()
+              set.add(changePartition)
+              requests.put(partitionId, set)
+              true
+          }
+      }
+    }
+    if (isFirstRequest && !batchHandleChangePartitionEnabled) {
+      batchHandleRequestPartitions(applicationId, shuffleId, Array(changePartition))
+    }
+  }
+
+  /** Returns the latest known location of the partition if its epoch is newer than `epoch`. */
+  private def getLatestPartition(
+      shuffleId: Int,
+      partitionId: Int,
+      epoch: Int): Option[PartitionLocation] =
+    Option(lifecycleManager.latestPartitionLocation.get(shuffleId))
+      .flatMap(locations => Option(locations.get(partitionId)))
+      .filter(_.getEpoch > epoch)
+
+  /**
+   * Reserves new slots for `changePartitions` (all of `shuffleId`, at most one per partition)
+   * and replies to every pending caller of those partitions.
+   *
+   * `blacklistPartition` is first invoked for each request that carries a cause. New slots are
+   * taken only from the shuffle's workers that are not blacklisted. Callers receive the new
+   * master location on success, or SLOT_NOT_AVAILABLE, SHUFFLE_NOT_REGISTERED, STAGE_ENDED or
+   * RESERVE_SLOTS_FAILED.
+   */
+  def batchHandleRequestPartitions(
+      applicationId: String,
+      shuffleId: Int,
+      changePartitions: Array[ChangePartitionRequest]): Unit = {
+    val requestsMap = changePartitionRequests.get(shuffleId)
+    if (requestsMap == null) {
+      // The shuffle was expired concurrently and its pending contexts dropped with it; answer
+      // the requests we were handed rather than failing with an NPE in the replies below.
+      logError(s"[handleChangePartition] shuffle $shuffleId not registered!")
+      changePartitions.foreach(_.context.reply(StatusCode.SHUFFLE_NOT_REGISTERED, None))
+      return
+    }
+
+    val changes = changePartitions.map { change =>
+      s"${change.shuffleId}-${change.partitionId}-${change.epoch}"
+    }.mkString("[", ",", "]")
+    logWarning(s"Batch handle change partition for $applicationId of $changes")
+
+    // Blacklist all failed workers
+    changePartitions.foreach { changePartition =>
+      changePartition.causes.foreach { cause =>
+        lifecycleManager.blacklistPartition(shuffleId, changePartition.oldPartition, cause)
+      }
+    }
+
+    // Clears a batch-mode in-batch mark; the set is gone if the shuffle expired meanwhile.
+    // Must be called while holding `requestsMap`.
+    def unmarkInBatch(partitionId: Int): Unit = {
+      if (batchHandleChangePartitionEnabled) {
+        Option(inBatchPartitions.get(shuffleId)).foreach(_.remove(partitionId))
+      }
+    }
+
+    // remove together to reduce lock time
+    def replySuccess(locations: Array[PartitionLocation]): Unit = {
+      requestsMap.synchronized {
+        locations.map { location =>
+          unmarkInBatch(location.getId)
+          // Here one partition id can be remove more than once,
+          // so need to filter null result before reply.
+          location -> Option(requestsMap.remove(location.getId))
+        }
+      }.foreach { case (newLocation, requests) =>
+        requests.foreach(_.asScala.toList.foreach(_.context.reply(
+          StatusCode.SUCCESS,
+          Option(newLocation))))
+      }
+    }
+
+    // remove together to reduce lock time
+    def replyFailure(status: StatusCode): Unit = {
+      requestsMap.synchronized {
+        changePartitions.map { changePartition =>
+          unmarkInBatch(changePartition.partitionId)
+          Option(requestsMap.remove(changePartition.partitionId))
+        }
+      }.foreach { requests =>
+        requests.foreach(_.asScala.toList.foreach(_.context.reply(status, None)))
+      }
+    }
+
+    // Get candidate worker that not in blacklist of shuffleId
+    val candidates =
+      lifecycleManager
+        .workerSnapshots(shuffleId)
+        .keySet()
+        .asScala
+        .filterNot(w => lifecycleManager.blacklist.containsKey(w))
+        .toList
+    // Replication needs a distinct worker for the slave copy.
+    val requiredCandidates = if (pushReplicateEnabled) 2 else 1
+    if (candidates.size < requiredCandidates) {
+      logError("[Update partition] failed for not enough candidates for revive.")
+      replyFailure(StatusCode.SLOT_NOT_AVAILABLE)
+      return
+    }
+
+    // PartitionSplit all contains oldPartition
+    val newlyAllocatedLocations =
+      reallocateChangePartitionRequestSlotsFromCandidates(changePartitions.toList, candidates)
+
+    if (!lifecycleManager.registeredShuffle.contains(shuffleId)) {
+      logError(s"[handleChangePartition] shuffle $shuffleId not registered!")
+      replyFailure(StatusCode.SHUFFLE_NOT_REGISTERED)
+      return
+    }
+
+    if (lifecycleManager.stageEndShuffleSet.contains(shuffleId)) {
+      logError(s"[handleChangePartition] shuffle $shuffleId already ended!")
+      replyFailure(StatusCode.STAGE_ENDED)
+      return
+    }
+
+    if (!lifecycleManager.reserveSlotsWithRetry(
+        applicationId,
+        shuffleId,
+        new util.HashSet(candidates.toSet.asJava),
+        newlyAllocatedLocations)) {
+      logError(s"[Update partition] failed for $shuffleId.")
+      replyFailure(StatusCode.RESERVE_SLOTS_FAILED)
+      return
+    }
+
+    val newMasterLocations =
+      newlyAllocatedLocations.asScala.flatMap {
+        case (workInfo, (masterLocations, slaveLocations)) =>
+          // Add all re-allocated slots to worker snapshots.
+          lifecycleManager.workerSnapshots(shuffleId).asScala
+            .get(workInfo)
+            .foreach { partitionLocationInfo =>
+              partitionLocationInfo.addMasterPartitions(shuffleId.toString, masterLocations)
+              lifecycleManager.updateLatestPartitionLocations(shuffleId, masterLocations)
+              partitionLocationInfo.addSlavePartitions(shuffleId.toString, slaveLocations)
+            }
+          // partition location can be null when call reserveSlotsWithRetry().
+          val locations = (masterLocations.asScala ++ slaveLocations.asScala.map(_.getPeer))
+            .distinct.filter(_ != null)
+          if (locations.nonEmpty) {
+            val changes = locations.map { partition =>
+              s"(partition ${partition.getId} epoch from ${partition.getEpoch - 1} to ${partition.getEpoch})"
+            }.mkString("[", ", ", "]")
+            logDebug(s"[Update partition] success for " +
+              s"shuffle ${Utils.makeShuffleKey(applicationId, shuffleId)}, succeed partitions: " +
+              s"$changes.")
+          }
+          locations
+      }
+    replySuccess(newMasterLocations.toArray)
+  }
+
+  /** Allocates one new slot per request from `candidates` into a fresh [[WorkerResource]]. */
+  private def reallocateChangePartitionRequestSlotsFromCandidates(
+      changePartitionRequests: List[ChangePartitionRequest],
+      candidates: List[WorkerInfo]): WorkerResource = {
+    val slots = new WorkerResource()
+    changePartitionRequests.foreach { partition =>
+      lifecycleManager.allocateFromCandidates(
+        partition.partitionId,
+        partition.epoch,
+        candidates,
+        slots)
+    }
+    slots
+  }
+
+  /** Drops all pending requests and in-batch marks kept for `shuffleId`. */
+  def removeExpiredShuffle(shuffleId: Int): Unit = {
+    changePartitionRequests.remove(shuffleId)
+    inBatchPartitions.remove(shuffleId)
+  }
+}
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 08ee5c6f..322806a9 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -63,19 +63,19 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
   private val rpcCacheConcurrencyLevel = conf.rpcCacheConcurrencyLevel
   private val rpcCacheExpireTime = conf.rpcCacheExpireTime
 
-  private val registeredShuffle = ConcurrentHashMap.newKeySet[Int]()
+  // Shared with ChangePartitionManager; package-private to keep it out of the public API.
+  private[client] val registeredShuffle = ConcurrentHashMap.newKeySet[Int]()
   private val shuffleMapperAttempts = new ConcurrentHashMap[Int, Array[Int]]()
   private val reducerFileGroupsMap =
     new ConcurrentHashMap[Int, Array[Array[PartitionLocation]]]()
   private val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
-  private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
+  // Shared with ChangePartitionManager; package-private to keep it out of the public API.
+  private[client] val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   // maintain each shuffle's map relation of WorkerInfo and partition location
   private val shuffleAllocatedWorkers = {
     new ConcurrentHashMap[Int, ConcurrentHashMap[WorkerInfo, PartitionLocationInfo]]()
   }
   // shuffle id -> (partitionId -> newest PartitionLocation)
-  private val latestPartitionLocation =
+  // Read by ChangePartitionManager; package-private to keep it out of the public API.
+  private[client] val latestPartitionLocation =
     new ConcurrentHashMap[Int, ConcurrentHashMap[Int, PartitionLocation]]()
   private val userIdentifier: UserIdentifier = IdentityProvider.instantiate(conf).provide()
   // noinspection UnstableApiUsage
@@ -99,34 +99,19 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       }
     }
 
-  private def updateLatestPartitionLocations(
+  /** Overwrites the latest known location of every partition in `locations` for `shuffleId`. */
+  private[client] def updateLatestPartitionLocations(
       shuffleId: Int,
       locations: util.List[PartitionLocation]): Unit = {
     val map = latestPartitionLocation.computeIfAbsent(shuffleId, newMapFunc)
     locations.asScala.foreach(location => map.put(location.getId, location))
   }
 
-  case class ChangePartitionRequest(
-      context: RequestLocationCallContext,
-      applicationId: String,
-      shuffleId: Int,
-      partitionId: Int,
-      epoch: Int,
-      oldPartition: PartitionLocation,
-      causes: Option[StatusCode])
-
   case class RegisterCallContext(context: RpcCallContext, partitionId: Int = 
-1) {
     def reply(response: PbRegisterShuffleResponse) = {
       context.reply(response)
     }
   }
 
-  // shuffleId -> (partitionId -> set of ChangePartition)
-  private val changePartitionRequests =
-    new ConcurrentHashMap[Int, ConcurrentHashMap[Integer, 
util.Set[ChangePartitionRequest]]]()
-  // shuffleId -> set of partition id
-  private val inBatchPartitions = new ConcurrentHashMap[Int, 
util.Set[Integer]]()
-
   case class CommitPartitionRequest(
       applicationId: String,
       shuffleId: Int,
@@ -147,13 +132,27 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
 
   // shuffle id -> ShuffleCommittedInfo
   private val committedPartitionInfo = new ConcurrentHashMap[Int, 
ShuffleCommittedInfo]()
+  /**
+   * When batch commit handling is enabled and `cause` is HARD_SPLIT, adds a
+   * CommitPartitionRequest for `partition` to the shuffle's pending commit requests;
+   * otherwise does nothing. ChangePartitionManager calls this before registering a
+   * change-partition request.
+   */
+  private[client] def registerCommitPartition(
+      applicationId: String,
+      shuffleId: Int,
+      partition: PartitionLocation,
+      cause: Option[StatusCode]): Unit = {
+    // handle hard split
+    if (batchHandleCommitPartitionEnabled && cause.contains(StatusCode.HARD_SPLIT)) {
+      val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
+      if (shuffleCommittedInfo == null) {
+        // The shuffle's commit info is gone (e.g. removed on expiry). Synchronizing on null
+        // would throw an NPE into the caller's RPC handling, so skip and let the
+        // change-partition path report the shuffle state.
+        logWarning(s"[registerCommitPartition] shuffle $shuffleId has no committed info, " +
+          s"skip registering partition ${partition.getId}-${partition.getEpoch} for commit.")
+      } else {
+        shuffleCommittedInfo.synchronized {
+          shuffleCommittedInfo.commitPartitionRequests
+            .add(CommitPartitionRequest(applicationId, shuffleId, partition))
+        }
+      }
+    }
+  }
 
   // register shuffle request waiting for response
   private val registeringShuffleRequest =
     new ConcurrentHashMap[Int, util.Set[RegisterCallContext]]()
 
   // blacklist
-  private val blacklist = new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]()
+  // worker -> (status code, registerTime). Read by ChangePartitionManager to exclude workers,
+  // so it is package-private rather than public.
+  private[client] val blacklist = new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]()
 
   // Threads
   private val forwardMessageThread =
@@ -161,20 +160,6 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
   private var checkForShuffleRemoval: ScheduledFuture[_] = _
   private var getBlacklist: ScheduledFuture[_] = _
 
-  private val batchHandleChangePartitionEnabled = 
conf.batchHandleChangePartitionEnabled
-  private val batchHandleChangePartitionExecutors = 
ThreadUtils.newDaemonCachedThreadPool(
-    "rss-lifecycle-manager-change-partition-executor",
-    conf.batchHandleChangePartitionNumThreads)
-  private val batchHandleChangePartitionRequestInterval =
-    conf.batchHandleChangePartitionRequestInterval
-  private val batchHandleChangePartitionSchedulerThread: 
Option[ScheduledExecutorService] =
-    if (batchHandleChangePartitionEnabled) {
-      Some(ThreadUtils.newDaemonSingleThreadScheduledExecutor(
-        "rss-lifecycle-manager-change-partition-scheduler"))
-    } else {
-      None
-    }
-
   private val batchHandleCommitPartitionEnabled = 
conf.batchHandleCommitPartitionEnabled
   private val batchHandleCommitPartitionExecutors = 
ThreadUtils.newDaemonCachedThreadPool(
     "rss-lifecycle-manager-commit-partition-executor",
@@ -208,6 +193,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       conf,
       rssHARetryClient,
       () => (totalWritten.sumThenReset(), fileCount.sumThenReset()))
+  private val changePartitionManager = new ChangePartitionManager(conf, this)
 
   // Since method `onStart` is executed when `rpcEnv.setupEndpoint` is 
executed, and
   // `rssHARetryClient` is initialized after `rpcEnv` is initialized, if 
method `onStart` contains
@@ -217,46 +203,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
   private def initialize(): Unit = {
     // noinspection ConvertExpressionToSAM
     heartbeater.start()
-    batchHandleChangePartitionSchedulerThread.foreach {
-      // noinspection ConvertExpressionToSAM
-      _.scheduleAtFixedRate(
-        new Runnable {
-          override def run(): Unit = {
-            try {
-              changePartitionRequests.asScala.foreach { case (shuffleId, 
requests) =>
-                requests.synchronized {
-                  batchHandleChangePartitionExecutors.submit {
-                    new Runnable {
-                      override def run(): Unit = {
-                        // For each partition only need handle one request
-                        val distinctPartitions = requests.asScala.filter { 
case (partitionId, _) =>
-                          
!inBatchPartitions.get(shuffleId).contains(partitionId)
-                        }.map { case (partitionId, request) =>
-                          inBatchPartitions.get(shuffleId).add(partitionId)
-                          request.asScala.toArray.maxBy(_.epoch)
-                        }.toArray
-                        if (distinctPartitions.nonEmpty) {
-                          batchHandleRequestPartitions(
-                            distinctPartitions.head.applicationId,
-                            shuffleId,
-                            distinctPartitions)
-                        }
-                      }
-                    }
-                  }
-                }
-              }
-            } catch {
-              case e: InterruptedException =>
-                logError("Partition split scheduler thread is shutting down, 
detail: ", e)
-                throw e
-            }
-          }
-        },
-        0,
-        batchHandleChangePartitionRequestInterval,
-        TimeUnit.MILLISECONDS)
-    }
+    changePartitionManager.start()
 
     batchHandleCommitPartitionSchedulerThread.foreach {
       _.scheduleAtFixedRate(
@@ -399,6 +346,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     getBlacklist.cancel(true)
     ThreadUtils.shutdown(forwardMessageThread, 800.millis)
 
+    changePartitionManager.stop()
     heartbeater.stop()
 
     rssHARetryClient.close()
@@ -504,7 +452,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       val oldPartition = 
PbSerDeUtils.fromPbPartitionLocation(pb.getOldPartition)
       logTrace(s"Received split request, " +
         s"$applicationId, $shuffleId, $partitionId, $epoch, $oldPartition")
-      handleRequestPartitionLocation(
+      changePartitionManager.handleRequestPartitionLocation(
         ChangeLocationCallContext(context),
         applicationId,
         shuffleId,
@@ -720,7 +668,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, 
partitionLocations))
     } else {
       // request new resource for this task
-      handleRequestPartitionLocation(
+      changePartitionManager.handleRequestPartitionLocation(
         ApplyNewLocationCallContext(context),
         applicationId,
         shuffleId,
@@ -730,7 +678,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     }
   }
 
-  private def blacklistPartition(
+  def blacklistPartition(
       shuffleId: Int,
       oldPartition: PartitionLocation,
       cause: StatusCode): Unit = {
@@ -778,7 +726,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     logWarning(s"Do Revive for shuffle ${Utils.makeShuffleKey(applicationId, 
shuffleId)}, " +
       s"oldPartition: $oldPartition, cause: $cause")
 
-    handleRequestPartitionLocation(
+    changePartitionManager.handleRequestPartitionLocation(
       ChangeLocationCallContext(context),
       applicationId,
       shuffleId,
@@ -788,195 +736,6 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       Some(cause))
   }
 
-  private val rpcContextRegisterFunc =
-    new util.function.Function[
-      Int,
-      ConcurrentHashMap[Integer, util.Set[ChangePartitionRequest]]]() {
-      override def apply(s: Int): ConcurrentHashMap[Integer, 
util.Set[ChangePartitionRequest]] =
-        new ConcurrentHashMap()
-    }
-
-  private val inBatchShuffleIdRegisterFunc = new util.function.Function[Int, 
util.Set[Integer]]() {
-    override def apply(s: Int): util.Set[Integer] = new util.HashSet[Integer]()
-  }
-
-  private def handleRequestPartitionLocation(
-      context: RequestLocationCallContext,
-      applicationId: String,
-      shuffleId: Int,
-      partitionId: Int,
-      oldEpoch: Int,
-      oldPartition: PartitionLocation,
-      cause: Option[StatusCode] = None): Unit = {
-
-    val changePartition = ChangePartitionRequest(
-      context,
-      applicationId,
-      shuffleId,
-      partitionId,
-      oldEpoch,
-      oldPartition,
-      cause)
-    // check if there exists request for the partition, if do just register
-    val requests = changePartitionRequests.computeIfAbsent(shuffleId, 
rpcContextRegisterFunc)
-    inBatchPartitions.computeIfAbsent(shuffleId, inBatchShuffleIdRegisterFunc)
-
-    // handle hard split
-    if (batchHandleCommitPartitionEnabled && cause.isDefined && cause.get == 
StatusCode.HARD_SPLIT) {
-      val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
-      shuffleCommittedInfo.synchronized {
-        shuffleCommittedInfo.commitPartitionRequests
-          .add(CommitPartitionRequest(applicationId, shuffleId, oldPartition))
-      }
-    }
-
-    requests.synchronized {
-      if (requests.containsKey(partitionId)) {
-        requests.get(partitionId).add(changePartition)
-        logTrace(s"[handleRequestPartitionLocation] For $shuffleId, request 
for same partition" +
-          s"$partitionId-$oldEpoch exists, register context.")
-        return
-      } else {
-        // If new slot for the partition has been allocated, reply and return.
-        // Else register and allocate for it.
-        getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { 
latestLoc =>
-          context.reply(StatusCode.SUCCESS, Some(latestLoc))
-          logDebug(s"New partition found, old partition $partitionId-$oldEpoch 
return it." +
-            s" shuffleId: $shuffleId $latestLoc")
-          return
-        }
-        val set = new util.HashSet[ChangePartitionRequest]()
-        set.add(changePartition)
-        requests.put(partitionId, set)
-      }
-    }
-    if (!batchHandleChangePartitionEnabled) {
-      batchHandleRequestPartitions(applicationId, shuffleId, 
Array(changePartition))
-    }
-  }
-
-  def batchHandleRequestPartitions(
-      applicationId: String,
-      shuffleId: Int,
-      changePartitions: Array[ChangePartitionRequest]): Unit = {
-    val requestsMap = changePartitionRequests.get(shuffleId)
-
-    val changes = changePartitions.map { change =>
-      s"${change.shuffleId}-${change.partitionId}-${change.epoch}"
-    }.mkString("[", ",", "]")
-    logWarning(s"Batch handle change partition for $applicationId of $changes")
-
-    // Blacklist all failed workers
-    if (changePartitions.exists(_.causes.isDefined)) {
-      changePartitions.filter(_.causes.isDefined).foreach { changePartition =>
-        blacklistPartition(shuffleId, changePartition.oldPartition, 
changePartition.causes.get)
-      }
-    }
-
-    // remove together to reduce lock time
-    def replySuccess(locations: Array[PartitionLocation]): Unit = {
-      requestsMap.synchronized {
-        locations.map { location =>
-          if (batchHandleChangePartitionEnabled) {
-            inBatchPartitions.get(shuffleId).remove(location.getId)
-          }
-          // Here one partition id can be remove more than once,
-          // so need to filter null result before reply.
-          location -> Option(requestsMap.remove(location.getId))
-        }
-      }.foreach { case (newLocation, requests) =>
-        requests.map(_.asScala.toList.foreach(_.context.reply(
-          StatusCode.SUCCESS,
-          Option(newLocation))))
-      }
-    }
-
-    // remove together to reduce lock time
-    def replyFailure(status: StatusCode): Unit = {
-      requestsMap.synchronized {
-        changePartitions.map { changePartition =>
-          if (batchHandleChangePartitionEnabled) {
-            
inBatchPartitions.get(shuffleId).remove(changePartition.partitionId)
-          }
-          Option(requestsMap.remove(changePartition.partitionId))
-        }
-      }.foreach { requests =>
-        requests.map(_.asScala.toList.foreach(_.context.reply(status, None)))
-      }
-    }
-
-    val candidates = workersNotBlacklisted(shuffleId)
-    if (candidates.size < 1 || (pushReplicateEnabled && candidates.size < 2)) {
-      logError("[Update partition] failed for not enough candidates for 
revive.")
-      replyFailure(StatusCode.SLOT_NOT_AVAILABLE)
-      return
-    }
-
-    // PartitionSplit all contains oldPartition
-    val newlyAllocatedLocations =
-      
reallocateChangePartitionRequestSlotsFromCandidates(changePartitions.toList, 
candidates)
-
-    if (!registeredShuffle.contains(shuffleId)) {
-      logError(s"[handleChangePartition] shuffle $shuffleId not registered!")
-      replyFailure(StatusCode.SHUFFLE_NOT_REGISTERED)
-      return
-    }
-
-    if (stageEndShuffleSet.contains(shuffleId)) {
-      logError(s"[handleChangePartition] shuffle $shuffleId already ended!")
-      replyFailure(StatusCode.STAGE_ENDED)
-      return
-    }
-
-    if (!reserveSlotsWithRetry(
-        applicationId,
-        shuffleId,
-        new util.HashSet(candidates.toSet.asJava),
-        newlyAllocatedLocations)) {
-      logError(s"[Update partition] failed for $shuffleId.")
-      replyFailure(StatusCode.RESERVE_SLOTS_FAILED)
-      return
-    }
-
-    val newMasterLocations =
-      newlyAllocatedLocations.asScala.flatMap {
-        case (workInfo, (masterLocations, slaveLocations)) =>
-          // Add all re-allocated slots to worker snapshots.
-          workerSnapshots(shuffleId).asScala.get(workInfo).foreach { 
partitionLocationInfo =>
-            partitionLocationInfo.addMasterPartitions(shuffleId.toString, 
masterLocations)
-            updateLatestPartitionLocations(shuffleId, masterLocations)
-            partitionLocationInfo.addSlavePartitions(shuffleId.toString, 
slaveLocations)
-          }
-          // partition location can be null when call reserveSlotsWithRetry().
-          val locations = (masterLocations.asScala ++ 
slaveLocations.asScala.map(_.getPeer))
-            .distinct.filter(_ != null)
-          if (locations.nonEmpty) {
-            val changes = locations.map { partition =>
-              s"(partition ${partition.getId} epoch from ${partition.getEpoch 
- 1} to ${partition.getEpoch})"
-            }.mkString("[", ", ", "]")
-            logDebug(s"[Update partition] success for " +
-              s"shuffle ${Utils.makeShuffleKey(applicationId, shuffleId)}, 
succeed partitions: " +
-              s"$changes.")
-          }
-          locations
-      }
-    replySuccess(newMasterLocations.toArray)
-  }
-
-  private def getLatestPartition(
-      shuffleId: Int,
-      partitionId: Int,
-      epoch: Int): Option[PartitionLocation] = {
-    val map = latestPartitionLocation.get(shuffleId)
-    if (map != null) {
-      val loc = map.get(partitionId)
-      if (loc != null && loc.getEpoch > epoch) {
-        return Some(loc)
-      }
-    }
-    None
-  }
-
   private def handleMapperEnd(
       context: RpcCallContext,
       applicationId: String,
@@ -1554,7 +1313,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
    * @param slots         the total allocated worker resources that need to be 
applied for the slot
    * @return If reserve all slots success
    */
-  private def reserveSlotsWithRetry(
+  def reserveSlotsWithRetry(
       applicationId: String,
       shuffleId: Int,
       candidates: util.HashSet[WorkerInfo],
@@ -1645,7 +1404,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
    * @param candidates WorkerInfo list can be used to offer worker slots
    * @param slots      Current WorkerResource
    */
-  private def allocateFromCandidates(
+  def allocateFromCandidates(
       id: Int,
       oldEpochId: Int,
       candidates: List[WorkerInfo],
@@ -1683,16 +1442,6 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     masterAndSlavePairs._1.add(masterLocation)
   }
 
-  private def reallocateChangePartitionRequestSlotsFromCandidates(
-      changePartitionRequests: List[ChangePartitionRequest],
-      candidates: List[WorkerInfo]): WorkerResource = {
-    val slots = new WorkerResource()
-    changePartitionRequests.foreach { partition =>
-      allocateFromCandidates(partition.partitionId, partition.epoch, 
candidates, slots)
-    }
-    slots
-  }
-
   private def reallocateSlotsFromCandidates(
       oldPartitions: List[PartitionLocation],
       candidates: List[WorkerInfo],
@@ -1750,12 +1499,11 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
         dataLostShuffleSet.remove(shuffleId)
         shuffleMapperAttempts.remove(shuffleId)
         stageEndShuffleSet.remove(shuffleId)
-        changePartitionRequests.remove(shuffleId)
-        inBatchPartitions.remove(shuffleId)
         committedPartitionInfo.remove(shuffleId)
         unregisterShuffleTime.remove(shuffleId)
         shuffleAllocatedWorkers.remove(shuffleId)
         latestPartitionLocation.remove(shuffleId)
+        changePartitionManager.removeExpiredShuffle(shuffleId)
 
         requestUnregisterShuffle(
           rssHARetryClient,
@@ -1920,8 +1668,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     }
   }
 
-  private def recordWorkerFailure(failures: ConcurrentHashMap[WorkerInfo, 
(StatusCode, Long)])
-      : Unit = {
+  def recordWorkerFailure(failures: ConcurrentHashMap[WorkerInfo, (StatusCode, 
Long)]): Unit = {
     val failedWorker = new ConcurrentHashMap[WorkerInfo, (StatusCode, 
Long)](failures)
     logInfo(s"Report Worker Failure: ${failedWorker.asScala}, current 
blacklist $blacklist")
     failedWorker.asScala.foreach { case (worker, (statusCode, registerTime)) =>
@@ -1961,14 +1708,6 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     }
   }
 
-  private def workersNotBlacklisted(shuffleId: Int): List[WorkerInfo] = {
-    workerSnapshots(shuffleId)
-      .keySet()
-      .asScala
-      .filter(w => !blacklist.keySet().contains(w))
-      .toList
-  }
-
   // Initialize at the end of LifecycleManager construction.
   initialize()
 }


Reply via email to