This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new e3576e4e [CELEBORN-117] refactor CommitManager, implements M/R 
Partition Commi… (#1060)
e3576e4e is described below

commit e3576e4e7af10485c92f55ea8e585e986f23260b
Author: Shuang <[email protected]>
AuthorDate: Thu Dec 15 11:09:59 2022 +0800

    [CELEBORN-117] refactor CommitManager, implements M/R Partition Commi… 
(#1060)
---
 .../celeborn/client/ChangePartitionManager.scala   |   2 +-
 .../org/apache/celeborn/client/CommitManager.scala | 485 +++------------------
 .../apache/celeborn/client/LifecycleManager.scala  |  70 +--
 .../celeborn/client/commit/CommitHandler.scala     | 370 ++++++++++++++++
 .../client/commit/MapPartitionCommitHandler.scala  | 177 ++++++++
 .../commit/ReducePartitionCommitHandler.scala      | 176 ++++++++
 .../common/meta/PartitionLocationInfo.scala        |  31 +-
 7 files changed, 838 insertions(+), 473 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index 5a3711a0..960ac1e6 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -271,7 +271,7 @@ class ChangePartitionManager(
       return
     }
 
-    if (lifecycleManager.commitManager.stageEndShuffleSet.contains(shuffleId)) 
{
+    if (lifecycleManager.commitManager.isStageEnd(shuffleId)) {
       logError(s"[handleChangePartition] shuffle $shuffleId already ended!")
       replyFailure(StatusCode.STAGE_ENDED)
       return
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 4bb2ffd9..21551871 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -19,22 +19,24 @@ package org.apache.celeborn.client
 
 import java.util
 import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, 
ScheduledFuture, TimeUnit}
-import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder}
+import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
 
 import scala.collection.JavaConverters._
 import scala.concurrent.duration.DurationInt
 
 import org.roaringbitmap.RoaringBitmap
 
+import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
+import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
+import org.apache.celeborn.client.commit.{CommitHandler, 
MapPartitionCommitHandler, ReducePartitionCommitHandler}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.meta.WorkerInfo
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType, 
StorageInfo}
-import 
org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, 
CommitFilesResponse}
 import org.apache.celeborn.common.protocol.message.StatusCode
-import org.apache.celeborn.common.rpc.RpcEndpointRef
-import org.apache.celeborn.common.util.{ThreadUtils, Utils}
+// Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
+import org.apache.celeborn.common.util.ThreadUtils
 
 case class CommitPartitionRequest(
     applicationId: String,
@@ -55,16 +57,15 @@ case class ShuffleCommittedInfo(
     allInFlightCommitRequestNum: AtomicInteger,
     partitionInFlightCommitRequestNum: ConcurrentHashMap[Int, AtomicInteger])
 
+object CommitManager {
+  type CommittedPartitionInfo = ConcurrentHashMap[Int, ShuffleCommittedInfo]
+}
+
 class CommitManager(appId: String, val conf: CelebornConf, lifecycleManager: 
LifecycleManager)
   extends Logging {
+
   // shuffle id -> ShuffleCommittedInfo
-  private val committedPartitionInfo = new ConcurrentHashMap[Int, 
ShuffleCommittedInfo]()
-  val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
-  val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
-  private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
-  // shuffleId -> in processing partitionId set
-  private val inProcessMapPartitionEndIds = new ConcurrentHashMap[Int, 
util.Set[Int]]()
-  private val pushReplicateEnabled = conf.pushReplicateEnabled
+  private val committedPartitionInfo = new CommittedPartitionInfo
 
   private val batchHandleCommitPartitionEnabled = 
conf.batchHandleCommitPartitionEnabled
   private val batchHandleCommitPartitionExecutors = 
ThreadUtils.newDaemonCachedThreadPool(
@@ -80,12 +81,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
       None
     }
   private var batchHandleCommitPartition: Option[ScheduledFuture[_]] = _
-
-  private val totalWritten = new LongAdder
-  private val fileCount = new LongAdder
-
-  private val testRetryCommitFiles = conf.testRetryCommitFiles
-  private val commitEpoch = new AtomicLong()
+  private val commitHandlers = new ConcurrentHashMap[PartitionType, 
CommitHandler]()
 
   def start(): Unit = {
     batchHandleCommitPartition = batchHandleCommitPartitionSchedulerThread.map 
{
@@ -96,15 +92,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
               batchHandleCommitPartitionExecutors.submit {
                 new Runnable {
                   val partitionType = 
lifecycleManager.getPartitionType(shuffleId)
-                  def isPartitionInProcess(partitionId: Int): Boolean = {
-                    if (inProcessMapPartitionEndIds.containsKey(shuffleId) &&
-                      
inProcessMapPartitionEndIds.get(shuffleId).contains(partitionId)) {
-                      true
-                    } else {
-                      false
-                    }
-                  }
-
+                  val commitHandler = getCommitHandler(shuffleId)
                   def incrementInflightNum(workerToRequests: Map[
                     WorkerInfo,
                     collection.Set[PartitionLocation]]): Unit = {
@@ -145,7 +133,8 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
                     if (partitionType == PartitionType.MAP) {
                       commitPartitionRequests.asScala.filterNot { request =>
                         shuffleCommittedInfo.handledCommitPartitionRequests
-                          .contains(request.partition) && isPartitionInProcess(
+                          .contains(request.partition) && 
commitHandler.isPartitionInProcess(
+                          shuffleId,
                           request.partition.getId)
                       }
                     } else {
@@ -163,8 +152,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
                       // partitions which are already committed by stageEnd 
process.
                       // But inProcessStageEndShuffleSet should have contain 
this shuffle id,
                       // can directly return.
-                      if (inProcessStageEndShuffleSet.contains(shuffleId) ||
-                        stageEndShuffleSet.contains(shuffleId)) {
+                      if (commitHandler.isStageEndOrInProcess(shuffleId)) {
                         logWarning(s"Shuffle $shuffleId ended or during 
processing stage end.")
                         shuffleCommittedInfo.commitPartitionRequests.clear()
                         Map.empty[WorkerInfo, Set[PartitionLocation]]
@@ -199,8 +187,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
                       }
                     }
                     if (workerToRequests.nonEmpty) {
-                      val commitFilesFailedWorkers =
-                        new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]()
+                      val commitFilesFailedWorkers = new ShuffleFailedWorkers()
                       val parallelism = workerToRequests.size
                       try {
                         ThreadUtils.parmap(
@@ -228,7 +215,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
                                 .toList
                                 .asJava
 
-                            commitFiles(
+                            commitHandler.commitFiles(
                               appId,
                               shuffleId,
                               shuffleCommittedInfo,
@@ -279,10 +266,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
 
   def removeExpiredShuffle(shuffleId: Int): Unit = {
     committedPartitionInfo.remove(shuffleId)
-    inProcessStageEndShuffleSet.remove(shuffleId)
-    inProcessMapPartitionEndIds.remove(shuffleId)
-    dataLostShuffleSet.remove(shuffleId)
-    stageEndShuffleSet.remove(shuffleId)
+    getCommitHandler(shuffleId).removeExpiredShuffle(shuffleId)
   }
 
   def registerCommitPartitionRequest(
@@ -299,401 +283,68 @@ class CommitManager(appId: String, val conf: 
CelebornConf, lifecycleManager: Lif
     }
   }
 
-  private def commitFiles(
-      applicationId: String,
-      shuffleId: Int,
-      shuffleCommittedInfo: ShuffleCommittedInfo,
-      worker: WorkerInfo,
-      masterIds: util.List[String],
-      slaveIds: util.List[String],
-      commitFilesFailedWorkers: ConcurrentHashMap[WorkerInfo, (StatusCode, 
Long)]): Unit = {
-
-    val res =
-      if (!testRetryCommitFiles) {
-        val commitFiles = CommitFiles(
-          applicationId,
-          shuffleId,
-          masterIds,
-          slaveIds,
-          lifecycleManager.shuffleMapperAttempts.get(shuffleId),
-          commitEpoch.incrementAndGet())
-        val res = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
-
-        res.status match {
-          case StatusCode.SUCCESS => // do nothing
-          case StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED 
| StatusCode.FAILED =>
-            logDebug(s"Request $commitFiles return ${res.status} for " +
-              s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
-            commitFilesFailedWorkers.put(worker, (res.status, 
System.currentTimeMillis()))
-          case _ => // won't happen
-        }
-        res
-      } else {
-        // for test
-        val commitFiles1 = CommitFiles(
-          applicationId,
-          shuffleId,
-          masterIds.subList(0, masterIds.size() / 2),
-          slaveIds.subList(0, slaveIds.size() / 2),
-          lifecycleManager.shuffleMapperAttempts.get(shuffleId),
-          commitEpoch.incrementAndGet())
-        val res1 = requestCommitFilesWithRetry(worker.endpoint, commitFiles1)
-
-        val commitFiles = CommitFiles(
-          applicationId,
-          shuffleId,
-          masterIds.subList(masterIds.size() / 2, masterIds.size()),
-          slaveIds.subList(slaveIds.size() / 2, slaveIds.size()),
-          lifecycleManager.shuffleMapperAttempts.get(shuffleId),
-          commitEpoch.incrementAndGet())
-        val res2 = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
-
-        
res1.committedMasterStorageInfos.putAll(res2.committedMasterStorageInfos)
-        res1.committedSlaveStorageInfos.putAll(res2.committedSlaveStorageInfos)
-        res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
-        CommitFilesResponse(
-          status = if (res1.status == StatusCode.SUCCESS) res2.status else 
res1.status,
-          (res1.committedMasterIds.asScala ++ 
res2.committedMasterIds.asScala).toList.asJava,
-          (res1.committedSlaveIds.asScala ++ 
res1.committedSlaveIds.asScala).toList.asJava,
-          (res1.failedMasterIds.asScala ++ 
res1.failedMasterIds.asScala).toList.asJava,
-          (res1.failedSlaveIds.asScala ++ 
res2.failedSlaveIds.asScala).toList.asJava,
-          res1.committedMasterStorageInfos,
-          res1.committedSlaveStorageInfos,
-          res1.committedMapIdBitMap,
-          res1.totalWritten + res2.totalWritten,
-          res1.fileCount + res2.fileCount)
-      }
-
-    shuffleCommittedInfo.synchronized {
-      // record committed partitionIds
-      res.committedMasterIds.asScala.foreach({
-        case commitMasterId =>
-          val partitionUniqueIdList = 
shuffleCommittedInfo.committedMasterIds.computeIfAbsent(
-            Utils.splitPartitionLocationUniqueId(commitMasterId)._1,
-            (k: Int) => new util.ArrayList[String]())
-          partitionUniqueIdList.add(commitMasterId)
-      })
-
-      res.committedSlaveIds.asScala.foreach({
-        case commitSlaveId =>
-          val partitionUniqueIdList = 
shuffleCommittedInfo.committedSlaveIds.computeIfAbsent(
-            Utils.splitPartitionLocationUniqueId(commitSlaveId)._1,
-            (k: Int) => new util.ArrayList[String]())
-          partitionUniqueIdList.add(commitSlaveId)
-      })
-
-      // record committed partitions storage hint and disk hint
-      
shuffleCommittedInfo.committedMasterStorageInfos.putAll(res.committedMasterStorageInfos)
-      
shuffleCommittedInfo.committedSlaveStorageInfos.putAll(res.committedSlaveStorageInfos)
-
-      // record failed partitions
-      shuffleCommittedInfo.failedMasterPartitionIds.putAll(
-        res.failedMasterIds.asScala.map((_, worker)).toMap.asJava)
-      shuffleCommittedInfo.failedSlavePartitionIds.putAll(
-        res.failedSlaveIds.asScala.map((_, worker)).toMap.asJava)
-
-      
shuffleCommittedInfo.committedMapIdBitmap.putAll(res.committedMapIdBitMap)
-
-      totalWritten.add(res.totalWritten)
-      fileCount.add(res.fileCount)
-      shuffleCommittedInfo.currentShuffleFileCount.add(res.fileCount)
-    }
-  }
-
-  private def requestCommitFilesWithRetry(
-      endpoint: RpcEndpointRef,
-      message: CommitFiles): CommitFilesResponse = {
-    val maxRetries = conf.requestCommitFilesMaxRetries
-    var retryTimes = 0
-    while (retryTimes < maxRetries) {
-      try {
-        if (testRetryCommitFiles && retryTimes < maxRetries - 1) {
-          endpoint.ask[CommitFilesResponse](message)
-          Thread.sleep(1000)
-          throw new Exception("Mock fail for CommitFiles")
-        } else {
-          return endpoint.askSync[CommitFilesResponse](message)
-        }
-      } catch {
-        case e: Throwable =>
-          retryTimes += 1
-          logError(
-            s"AskSync CommitFiles for ${message.shuffleId} failed (attempt 
$retryTimes/$maxRetries).",
-            e)
-      }
-    }
-
-    CommitFilesResponse(
-      StatusCode.FAILED,
-      List.empty.asJava,
-      List.empty.asJava,
-      message.masterIds,
-      message.slaveIds)
+  def tryFinalCommit(shuffleId: Int): Boolean = {
+    getCommitHandler(shuffleId).tryFinalCommit(
+      shuffleId,
+      r => lifecycleManager.recordWorkerFailure(r))
   }
 
-  def finalCommit(
-      applicationId: String,
-      shuffleId: Int,
-      fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]]): 
Unit = {
-    if (stageEndShuffleSet.contains(shuffleId)) {
-      logInfo(s"[handleStageEnd] Shuffle $shuffleId already ended!")
-      return
-    }
-    inProcessStageEndShuffleSet.synchronized {
-      if (inProcessStageEndShuffleSet.contains(shuffleId)) {
-        logWarning(s"[handleStageEnd] Shuffle $shuffleId is in process!")
-        return
-      }
-      inProcessStageEndShuffleSet.add(shuffleId)
-    }
-    // ask allLocations workers holding partitions to commit files
-    val allocatedWorkers = 
lifecycleManager.shuffleAllocatedWorkers.get(shuffleId)
-    val dataLost = handleCommitFiles(applicationId, shuffleId, 
allocatedWorkers, None, fileGroups)
-
-    // reply
-    if (!dataLost) {
-      logInfo(s"Succeed to handle stageEnd for $shuffleId.")
-      // record in stageEndShuffleSet
-      stageEndShuffleSet.add(shuffleId)
-    } else {
-      logError(s"Failed to handle stageEnd for $shuffleId, lost file!")
-      dataLostShuffleSet.add(shuffleId)
-      // record in stageEndShuffleSet
-      stageEndShuffleSet.add(shuffleId)
-    }
-    inProcessStageEndShuffleSet.remove(shuffleId)
+  def finalPartitionCommit(shuffleId: Int, partitionId: Int): Boolean = {
+    getCommitHandler(shuffleId).finalPartitionCommit(
+      shuffleId,
+      partitionId,
+      r => lifecycleManager.recordWorkerFailure(r))
   }
 
-  private def handleCommitFiles(
-      applicationId: String,
-      shuffleId: Int,
-      allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo],
-      partitionIdOpt: Option[Int] = None,
-      fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]]): 
Boolean = {
-    val masterPartMap = new ConcurrentHashMap[String, PartitionLocation]
-    val slavePartMap = new ConcurrentHashMap[String, PartitionLocation]
-    val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
-    val commitFilesFailedWorkers = new ConcurrentHashMap[WorkerInfo, 
(StatusCode, Long)]()
-    val commitFileStartTime = System.nanoTime()
-
-    val parallelism = Math.min(allocatedWorkers.size(), conf.rpcMaxParallelism)
-    ThreadUtils.parmap(
-      allocatedWorkers.asScala.to,
-      "CommitFiles",
-      parallelism) { case (worker, partitionLocationInfo) =>
-      if (partitionLocationInfo.containsShuffle(shuffleId.toString)) {
-        val masterParts =
-          partitionLocationInfo.getMasterLocations(shuffleId.toString, 
partitionIdOpt)
-        val slaveParts = 
partitionLocationInfo.getSlaveLocations(shuffleId.toString, partitionIdOpt)
-        masterParts.asScala.foreach { p =>
-          val partition = new PartitionLocation(p)
-          partition.setFetchPort(worker.fetchPort)
-          partition.setPeer(null)
-          masterPartMap.put(partition.getUniqueId, partition)
-        }
-        slaveParts.asScala.foreach { p =>
-          val partition = new PartitionLocation(p)
-          partition.setFetchPort(worker.fetchPort)
-          partition.setPeer(null)
-          slavePartMap.put(partition.getUniqueId, partition)
-        }
-
-        val (masterIds, slaveIds) = shuffleCommittedInfo.synchronized {
-          (
-            masterParts.asScala
-              
.filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
-              .map(_.getUniqueId).asJava,
-            slaveParts.asScala
-              
.filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
-              .map(_.getUniqueId).asJava)
-        }
-
-        commitFiles(
-          applicationId,
-          shuffleId,
-          shuffleCommittedInfo,
-          worker,
-          masterIds,
-          slaveIds,
-          commitFilesFailedWorkers)
-      }
-    }
-    lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
-    // check all inflight request complete, for map partition, it's for single 
partitionId
-    waitInflightRequestComplete(shuffleId, shuffleCommittedInfo, 
partitionIdOpt)
-
-    // check all data lost or not, for map partition, it's for single 
partitionId
-    val dataLost = checkDataLost(applicationId, shuffleId, partitionIdOpt)
-
-    if (!dataLost) {
-      val committedPartitions = new util.HashMap[String, PartitionLocation]
-      getPartitionUniqueIds(shuffleCommittedInfo.committedMasterIds, 
partitionIdOpt).foreach { id =>
-        if (shuffleCommittedInfo.committedMasterStorageInfos.get(id) == null) {
-          logDebug(s"$applicationId-$shuffleId $id storage hint was not 
returned")
-        } else {
-          masterPartMap.get(id).setStorageInfo(
-            shuffleCommittedInfo.committedMasterStorageInfos.get(id))
-          
masterPartMap.get(id).setMapIdBitMap(shuffleCommittedInfo.committedMapIdBitmap.get(id))
-          committedPartitions.put(id, masterPartMap.get(id))
-        }
-      }
-
-      getPartitionUniqueIds(shuffleCommittedInfo.committedSlaveIds, 
partitionIdOpt).foreach { id =>
-        val slavePartition = slavePartMap.get(id)
-        if (shuffleCommittedInfo.committedSlaveStorageInfos.get(id) == null) {
-          logDebug(s"$applicationId-$shuffleId $id storage hint was not 
returned")
-        } else {
-          
slavePartition.setStorageInfo(shuffleCommittedInfo.committedSlaveStorageInfos.get(id))
-          val masterPartition = committedPartitions.get(id)
-          if (masterPartition ne null) {
-            masterPartition.setPeer(slavePartition)
-            slavePartition.setPeer(masterPartition)
-          } else {
-            logInfo(s"Shuffle $shuffleId partition $id: master lost, " +
-              s"use slave $slavePartition.")
-            
slavePartition.setMapIdBitMap(shuffleCommittedInfo.committedMapIdBitmap.get(id))
-            committedPartitions.put(id, slavePartition)
-          }
-        }
-      }
-
-      committedPartitions.values().asScala.foreach { partition =>
-        val partitionLocations = fileGroups.computeIfAbsent(
-          partition.getId,
-          (k: Integer) => new util.HashSet[PartitionLocation]())
-        partitionLocations.add(partition)
-      }
-
-      logInfo(s"Shuffle $shuffleId " +
-        s"commit files complete. File count 
${shuffleCommittedInfo.currentShuffleFileCount.sum()} " +
-        s"using ${(System.nanoTime() - commitFileStartTime) / 1000000} ms")
-    }
-
-    dataLost
+  def isStageEnd(shuffleId: Int): Boolean = {
+    getCommitHandler(shuffleId).isStageEnd(shuffleId)
   }
 
-  private def getPartitionIds(
-      partitionIds: ConcurrentHashMap[String, WorkerInfo],
-      partitionIdOpt: Option[Int]): util.Map[String, WorkerInfo] = {
-    partitionIdOpt match {
-      case Some(partitionId) => partitionIds.asScala.filter(p =>
-          Utils.splitPartitionLocationUniqueId(p._1)._1 ==
-            partitionId).asJava
-      case None => partitionIds
-    }
+  def setStageEnd(shuffleId: Int): Unit = {
+    getCommitHandler(shuffleId).setStageEnd(shuffleId)
   }
 
-  private def waitInflightRequestComplete(
-      shuffleId: Int,
-      shuffleCommittedInfo: ShuffleCommittedInfo,
-      partitionIdOpt: Option[Int]): Unit = {
-    lifecycleManager.getPartitionType(shuffleId) match {
-      case PartitionType.REDUCE =>
-        while (shuffleCommittedInfo.allInFlightCommitRequestNum.get() > 0) {
-          Thread.sleep(1000)
-        }
-      case PartitionType.MAP => partitionIdOpt match {
-          case Some(partitionId) =>
-            if 
(shuffleCommittedInfo.partitionInFlightCommitRequestNum.containsKey(partitionId))
 {
-              while 
(shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(
-                  partitionId).get() > 0) {
-                Thread.sleep(1000)
-              }
-            }
-        }
-    }
+  def isStageDataLost(shuffleId: Int): Boolean = {
+    getCommitHandler(shuffleId).isStageDataLost(shuffleId)
   }
 
-  private def checkDataLost(
-      applicationId: String,
-      shuffleId: Int,
-      partitionIdOpt: Option[Int]): Boolean = {
-    val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
-    val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
-    val masterPartitionUniqueIdMap =
-      getPartitionIds(shuffleCommittedInfo.failedMasterPartitionIds, 
partitionIdOpt)
-    if (!pushReplicateEnabled && masterPartitionUniqueIdMap.size() != 0) {
-      val msg =
-        masterPartitionUniqueIdMap.asScala.map {
-          case (partitionUniqueId, workerInfo) =>
-            s"Lost partition $partitionUniqueId in worker 
[${workerInfo.readableAddress()}]"
-        }.mkString("\n")
-      logError(
-        s"""
-           |For shuffle $shuffleKey partition data lost:
-           |$msg
-           |""".stripMargin)
-      true
+  private def getCommitHandler(shuffleId: Int): CommitHandler = {
+    val partitionType = lifecycleManager.getPartitionType(shuffleId)
+    if (commitHandlers.containsKey(partitionType)) {
+      commitHandlers.get(partitionType)
     } else {
-      val slavePartitionUniqueIdMap =
-        getPartitionIds(shuffleCommittedInfo.failedSlavePartitionIds, 
partitionIdOpt)
-      val failedBothPartitionIdsToWorker = 
masterPartitionUniqueIdMap.asScala.flatMap {
-        case (partitionUniqueId, worker) =>
-          if (slavePartitionUniqueIdMap.asScala.contains(partitionUniqueId)) {
-            Some(partitionUniqueId -> (worker, 
slavePartitionUniqueIdMap.get(partitionUniqueId)))
-          } else {
-            None
+      commitHandlers.computeIfAbsent(
+        partitionType,
+        (partitionType: PartitionType) => {
+          partitionType match {
+            case PartitionType.REDUCE => new ReducePartitionCommitHandler(
+                appId,
+                conf,
+                lifecycleManager.shuffleAllocatedWorkers,
+                lifecycleManager.reducerFileGroupsMap,
+                committedPartitionInfo,
+                lifecycleManager.shuffleMapperAttempts)
+            case PartitionType.MAP => new MapPartitionCommitHandler(
+                appId,
+                conf,
+                lifecycleManager
+                  .shuffleAllocatedWorkers,
+                lifecycleManager.reducerFileGroupsMap,
+                committedPartitionInfo)
+            case _ => throw new UnsupportedOperationException(
+                s"Unexpected ShufflePartitionType for CommitManager: 
$partitionType")
           }
-      }
-      if (failedBothPartitionIdsToWorker.nonEmpty) {
-        val msg = failedBothPartitionIdsToWorker.map {
-          case (partitionUniqueId, (masterWorker, slaveWorker)) =>
-            s"Lost partition $partitionUniqueId " +
-              s"in master worker [${masterWorker.readableAddress()}] and slave 
worker [$slaveWorker]"
-        }.mkString("\n")
-        logError(
-          s"""
-             |For shuffle $shuffleKey partition data lost:
-             |$msg
-             |""".stripMargin)
-        true
-      } else {
-        false
-      }
+        })
     }
   }
 
-  private def getPartitionUniqueIds(
-      ids: ConcurrentHashMap[Int, util.List[String]],
-      partitionIdOpt: Option[Int]): Iterable[String] = {
-    partitionIdOpt match {
-      case Some(partitionId) => ids.asScala.filter(_._1 == 
partitionId).flatMap(_._2.asScala)
-      case None => ids.asScala.flatMap(_._2.asScala)
+  def commitMetrics(): (Long, Long) = {
+    var totalWritten = 0L
+    var totalFileCount = 0L
+    commitHandlers.asScala.values.foreach { commitHandler =>
+      totalWritten += commitHandler.commitMetrics._1
+      totalFileCount += commitHandler.commitMetrics._2
     }
+    (totalWritten, totalFileCount)
   }
-
-  def finalPartitionCommit(
-      applicationId: String,
-      shuffleId: Int,
-      fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]],
-      partitionId: Int): Boolean = {
-    val inProcessingPartitionIds =
-      inProcessMapPartitionEndIds.computeIfAbsent(shuffleId, (k: Int) => new 
util.HashSet[Int]())
-    inProcessingPartitionIds.add(partitionId)
-
-    val allocatedWorkers =
-      lifecycleManager.shuffleAllocatedWorkers.get(shuffleId).asScala.filter(p 
=>
-        p._2.containsRelatedShuffleOrPartition(shuffleId.toString, 
Option(partitionId))).asJava
-
-    var dataCommitSuccess = true
-    if (!allocatedWorkers.isEmpty) {
-      dataCommitSuccess =
-        !handleCommitFiles(
-          applicationId,
-          shuffleId,
-          allocatedWorkers,
-          Option(partitionId),
-          fileGroups)
-    }
-
-    // release resources and clear worker info
-    allocatedWorkers.asScala.foreach { case (_, partitionLocationInfo) =>
-      partitionLocationInfo.removeAllRelatedPartitions(shuffleId.toString, 
Option(partitionId))
-    }
-    inProcessingPartitionIds.remove(partitionId)
-
-    dataCommitSuccess
-  }
-
-  def commitMetrics(): (Long, Long) = (totalWritten.sumThenReset(), 
fileCount.sumThenReset())
 }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 287ecd30..06fbb2a5 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -29,6 +29,7 @@ import scala.util.Random
 import com.google.common.annotations.VisibleForTesting
 import com.google.common.cache.{Cache, CacheBuilder}
 
+import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, 
ShuffleFailedWorkers, ShuffleFileGroups, ShuffleMapperAttempts}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.haclient.RssHARetryClient
 import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}
@@ -41,8 +42,18 @@ import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc._
 import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, 
RemoteNettyRpcCallContext}
 import org.apache.celeborn.common.util.{PbSerDeUtils, ThreadUtils, Utils}
+// Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
 
+object LifecycleManager {
+  type ShuffleFileGroups =
+    ConcurrentHashMap[Int, ConcurrentHashMap[Integer, 
util.Set[PartitionLocation]]]
+  type ShuffleAllocatedWorkers =
+    ConcurrentHashMap[Int, ConcurrentHashMap[WorkerInfo, 
PartitionLocationInfo]]
+  type ShuffleMapperAttempts = ConcurrentHashMap[Int, Array[Int]]
+  type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]
+}
+
 class LifecycleManager(appId: String, val conf: CelebornConf) extends 
RpcEndpoint with Logging {
 
   private val lifecycleHost = Utils.localHostName
@@ -63,13 +74,11 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
   private val rpcCacheExpireTime = conf.rpcCacheExpireTime
 
   val registeredShuffle = ConcurrentHashMap.newKeySet[Int]()
-  val shuffleMapperAttempts = new ConcurrentHashMap[Int, Array[Int]]()
-  private val reducerFileGroupsMap =
-    new ConcurrentHashMap[Int, ConcurrentHashMap[Integer, 
util.Set[PartitionLocation]]]()
-  private val shuffleTaskInfo = new ShuffleTaskInfo()
+  val shuffleMapperAttempts = new ShuffleMapperAttempts
+  val reducerFileGroupsMap = new ShuffleFileGroups
+  private val shuffleTaskInfo = new ShuffleTaskInfo
   // maintain each shuffle's map relation of WorkerInfo and partition location
-  val shuffleAllocatedWorkers =
-    new ConcurrentHashMap[Int, ConcurrentHashMap[WorkerInfo, 
PartitionLocationInfo]]()
+  val shuffleAllocatedWorkers = new ShuffleAllocatedWorkers
   // shuffle id -> (partitionId -> newest PartitionLocation)
   val latestPartitionLocation =
     new ConcurrentHashMap[Int, ConcurrentHashMap[Int, PartitionLocation]]()
@@ -115,7 +124,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     new ConcurrentHashMap[Int, util.Set[RegisterCallContext]]()
 
   // blacklist
-  val blacklist = new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]()
+  val blacklist = new ShuffleFailedWorkers()
 
   // Threads
   private val forwardMessageThread =
@@ -426,7 +435,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     // won't be empty since master will reply SlotNotAvailable status when 
reserved slots is empty.
     val slots = res.workerResource
     val candidatesWorkers = new util.HashSet(slots.keySet())
-    val connectFailedWorkers = new ConcurrentHashMap[WorkerInfo, (StatusCode, 
Long)]()
+    val connectFailedWorkers = new ShuffleFailedWorkers()
 
     // Second, for each worker, try to initialize the endpoint.
     val parallelism = Math.min(Math.max(1, slots.size()), 
conf.rpcMaxParallelism)
@@ -524,7 +533,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       oldPartition: PartitionLocation,
       cause: StatusCode): Unit = {
     // only blacklist if cause is PushDataFailMain
-    val failedWorker = new ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]()
+    val failedWorker = new ShuffleFailedWorkers()
     if (cause == StatusCode.PUSH_DATA_FAIL_MASTER && oldPartition != null) {
       val tmpWorker = oldPartition.getWorker
       val worker = workerSnapshots(shuffleId).keySet().asScala
@@ -629,7 +638,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     val delta = 100
     // reduce partition need wait stage end. While map partition Would commit 
every partition synchronously.
     if (getPartitionType(shuffleId) == PartitionType.REDUCE) {
-      while (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
+      while (!commitManager.isStageEnd(shuffleId)) {
         Thread.sleep(delta)
         if (timeout <= 0) {
           logError(s"[handleGetReducerFileGroup] Wait for handleStageEnd 
Timeout! $shuffleId.")
@@ -646,7 +655,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
         s" ${stageEndTimeout - timeout}ms")
     }
 
-    if (commitManager.dataLostShuffleSet.contains(shuffleId)) {
+    if (commitManager.isStageDataLost(shuffleId)) {
       context.reply(
         GetReducerFileGroupResponse(
           StatusCode.SHUFFLE_DATA_LOST,
@@ -686,21 +695,20 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       logInfo(s"[handleStageEnd]" +
         s"$shuffleId not registered, maybe no shuffle data within this stage.")
       // record in stageEndShuffleSet
-      commitManager.stageEndShuffleSet.add(shuffleId)
+      commitManager.setStageEnd(shuffleId)
       return
     }
-    commitManager.finalCommit(
-      applicationId,
-      shuffleId,
-      reducerFileGroupsMap.get(shuffleId))
-    // release resources and clear worker info
-    workerSnapshots(shuffleId).asScala.foreach { case (_, 
partitionLocationInfo) =>
-      partitionLocationInfo.removeMasterPartitions(shuffleId.toString)
-      partitionLocationInfo.removeSlavePartitions(shuffleId.toString)
+
+    if (commitManager.tryFinalCommit(shuffleId)) {
+      // release resources and clear worker info
+      workerSnapshots(shuffleId).asScala.foreach { case (_, 
partitionLocationInfo) =>
+        partitionLocationInfo.removeMasterPartitions(shuffleId.toString)
+        partitionLocationInfo.removeSlavePartitions(shuffleId.toString)
+      }
+      requestReleaseSlots(
+        rssHARetryClient,
+        ReleaseSlots(applicationId, shuffleId, List.empty.asJava, 
List.empty.asJava))
     }
-    requestReleaseSlots(
-      rssHARetryClient,
-      ReleaseSlots(applicationId, shuffleId, List.empty.asJava, 
List.empty.asJava))
   }
 
   private def handleMapPartitionEnd(
@@ -712,7 +720,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       partitionId: Int): Unit = {
     def reply(result: Boolean): Unit = {
       val message =
-        s"to handle MapPartitionEnd for ${Utils.makeMapKey(appId, shuffleId, 
mapId, attemptId)}, " +
+        s"to handle MapPartitionEnd for ${Utils.makeMapKey(applicationId, 
shuffleId, mapId, attemptId)}, " +
           s"$partitionId.";
       result match {
         case true => // if already committed by another try
@@ -725,9 +733,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     }
 
     val dataCommitSuccess = commitManager.finalPartitionCommit(
-      applicationId,
       shuffleId,
-      reducerFileGroupsMap.get(shuffleId),
       partitionId)
     reply(dataCommitSuccess)
   }
@@ -737,12 +743,12 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       shuffleId: Int): Unit = {
     if (getPartitionType(shuffleId) == PartitionType.REDUCE) {
       // if StageEnd has not been handled, trigger StageEnd
-      if (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
+      if (!commitManager.isStageEnd(shuffleId)) {
         logInfo(s"Call StageEnd before Unregister Shuffle $shuffleId.")
         handleStageEnd(appId, shuffleId)
         var timeout = stageEndTimeout
         val delta = 100
-        while (!commitManager.stageEndShuffleSet.contains(shuffleId) && 
timeout > 0) {
+        while (!commitManager.isStageEnd(shuffleId) && timeout > 0) {
           Thread.sleep(delta)
           timeout = timeout - delta
         }
@@ -790,7 +796,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       applicationId: String,
       shuffleId: Int,
       slots: WorkerResource): util.List[WorkerInfo] = {
-    val reserveSlotFailedWorkers = new ConcurrentHashMap[WorkerInfo, 
(StatusCode, Long)]()
+    val reserveSlotFailedWorkers = new ShuffleFailedWorkers()
     val failureInfos = new util.concurrent.CopyOnWriteArrayList[String]()
     val parallelism = Math.min(Math.max(1, slots.size()), 
conf.rpcMaxParallelism)
     ThreadUtils.parmap(slots.asScala.to, "ReserveSlot", parallelism) {
@@ -1158,7 +1164,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
             case _ => false
           }
         }.asJava
-      val reservedBlackList = new ConcurrentHashMap[WorkerInfo, (StatusCode, 
Long)]()
+      val reservedBlackList = new ShuffleFailedWorkers()
       reservedBlackList.putAll(reserved)
       blacklist.clear()
       blacklist.putAll(
@@ -1266,8 +1272,8 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     }
   }
 
-  def recordWorkerFailure(failures: ConcurrentHashMap[WorkerInfo, (StatusCode, 
Long)]): Unit = {
-    val failedWorker = new ConcurrentHashMap[WorkerInfo, (StatusCode, 
Long)](failures)
+  def recordWorkerFailure(failures: ShuffleFailedWorkers): Unit = {
+    val failedWorker = new ShuffleFailedWorkers(failures)
     logInfo(s"Report Worker Failure: ${failedWorker.asScala}, current 
blacklist $blacklist")
     failedWorker.asScala.foreach { case (worker, (statusCode, registerTime)) =>
       if (!blacklist.containsKey(worker)) {
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
new file mode 100644
index 00000000..7a4aad30
--- /dev/null
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client.commit
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.{AtomicLong, LongAdder}
+
+import scala.collection.JavaConverters._
+
+import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
+import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, 
ShuffleFailedWorkers, ShuffleFileGroups, ShuffleMapperAttempts}
+import org.apache.celeborn.client.ShuffleCommittedInfo
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, 
CommitFilesResponse}
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.rpc.RpcEndpointRef
+import org.apache.celeborn.common.util.{ThreadUtils, Utils}
+// Can Remove this if celeborn don't support scala211 in future
+import org.apache.celeborn.common.util.FunctionConverter._
+
/**
 * Result of committing a shuffle's files on its workers: the committed master
 * and slave partition locations keyed by partition unique id, plus the workers
 * whose CommitFiles request failed.
 */
case class CommitResult(
    masterPartitionLocationMap: ConcurrentHashMap[String, PartitionLocation],
    slavePartitionLocationMap: ConcurrentHashMap[String, PartitionLocation],
    commitFilesFailedWorkers: ShuffleFailedWorkers)

/**
 * Partition-type-specific commit logic shared by the Map and Reduce partition
 * commit handlers: issuing CommitFiles RPCs to allocated workers, recording
 * the committed/failed partitions into [[ShuffleCommittedInfo]], and detecting
 * shuffle data loss.
 */
abstract class CommitHandler(
    appId: String,
    conf: CelebornConf,
    allocatedWorkers: ShuffleAllocatedWorkers,
    reducerFileGroupsMap: ShuffleFileGroups,
    committedPartitionInfo: CommittedPartitionInfo) extends Logging {

  private val pushReplicateEnabled = conf.pushReplicateEnabled
  private val testRetryCommitFiles = conf.testRetryCommitFiles
  // monotonically increasing epoch attached to every CommitFiles request
  private val commitEpoch = new AtomicLong()
  // global commit metrics, drained by commitMetrics()
  private val totalWritten = new LongAdder
  private val fileCount = new LongAdder

  /** The partition type this handler serves (MAP or REDUCE). */
  def getPartitionType(): PartitionType

  // Stage-end related queries default to false; only the Reduce handler
  // overrides them, since Map partitions commit per-partition without a
  // stage-end phase.
  def isStageEnd(shuffleId: Int): Boolean = false

  def isStageEndOrInProcess(shuffleId: Int): Boolean = false

  def isStageDataLost(shuffleId: Int): Boolean = false

  def setStageEnd(shuffleId: Int): Unit

  /** Whether `partitionId` of `shuffleId` currently has a commit in flight. */
  def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = false

  /**
   * when someone calls tryFinalCommit, the function will return true if there is no one ever do final commit before,
   * otherwise it will return false.
   * @return
   */
  def tryFinalCommit(
      shuffleId: Int,
      recordWorkerFailure: ShuffleFailedWorkers => Unit): Boolean

  def finalPartitionCommit(
      shuffleId: Int,
      partitionId: Int,
      recordWorkerFailure: ShuffleFailedWorkers => Unit): Boolean

  /** Drops any per-shuffle state held by this handler once the shuffle expires. */
  def removeExpiredShuffle(shuffleId: Int): Unit

  def getShuffleMapperAttempts(shuffleId: Int): Array[Int]
+
+  def parallelCommitFiles(
+      shuffleId: Int,
+      allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo],
+      partitionIdOpt: Option[Int] = None): CommitResult = {
+    val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
+    val masterPartMap = new ConcurrentHashMap[String, PartitionLocation]
+    val slavePartMap = new ConcurrentHashMap[String, PartitionLocation]
+    val commitFilesFailedWorkers = new ShuffleFailedWorkers()
+    val commitFileStartTime = System.nanoTime()
+    val parallelism = Math.min(allocatedWorkers.size(), conf.rpcMaxParallelism)
+    ThreadUtils.parmap(
+      allocatedWorkers.asScala.to,
+      "CommitFiles",
+      parallelism) { case (worker, partitionLocationInfo) =>
+      if (partitionLocationInfo.containsShuffle(shuffleId.toString)) {
+        val masterParts =
+          partitionLocationInfo.getMasterLocations(shuffleId.toString, 
partitionIdOpt)
+        val slaveParts = 
partitionLocationInfo.getSlaveLocations(shuffleId.toString, partitionIdOpt)
+        masterParts.asScala.foreach { p =>
+          val partition = new PartitionLocation(p)
+          partition.setFetchPort(worker.fetchPort)
+          partition.setPeer(null)
+          masterPartMap.put(partition.getUniqueId, partition)
+        }
+        slaveParts.asScala.foreach { p =>
+          val partition = new PartitionLocation(p)
+          partition.setFetchPort(worker.fetchPort)
+          partition.setPeer(null)
+          slavePartMap.put(partition.getUniqueId, partition)
+        }
+
+        val (masterIds, slaveIds) = shuffleCommittedInfo.synchronized {
+          (
+            masterParts.asScala
+              
.filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
+              .map(_.getUniqueId).asJava,
+            slaveParts.asScala
+              
.filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
+              .map(_.getUniqueId).asJava)
+        }
+
+        commitFiles(
+          appId,
+          shuffleId,
+          shuffleCommittedInfo,
+          worker,
+          masterIds,
+          slaveIds,
+          commitFilesFailedWorkers)
+      }
+    }
+
+    logInfo(s"Shuffle $shuffleId " +
+      s"commit files complete. File count 
${shuffleCommittedInfo.currentShuffleFileCount.sum()} " +
+      s"using ${(System.nanoTime() - commitFileStartTime) / 1000000} ms")
+
+    CommitResult(masterPartMap, slavePartMap, commitFilesFailedWorkers)
+  }
+
+  def commitFiles(
+      applicationId: String,
+      shuffleId: Int,
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      worker: WorkerInfo,
+      masterIds: util.List[String],
+      slaveIds: util.List[String],
+      commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {
+
+    val res =
+      if (!testRetryCommitFiles) {
+        val commitFiles = CommitFiles(
+          applicationId,
+          shuffleId,
+          masterIds,
+          slaveIds,
+          getShuffleMapperAttempts(shuffleId),
+          commitEpoch.incrementAndGet())
+        val res = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
+
+        res.status match {
+          case StatusCode.SUCCESS => // do nothing
+          case StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED 
| StatusCode.FAILED =>
+            logDebug(s"Request $commitFiles return ${res.status} for " +
+              s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
+            commitFilesFailedWorkers.put(worker, (res.status, 
System.currentTimeMillis()))
+          case _ => // won't happen
+        }
+        res
+      } else {
+        // for test
+        val commitFiles1 = CommitFiles(
+          applicationId,
+          shuffleId,
+          masterIds.subList(0, masterIds.size() / 2),
+          slaveIds.subList(0, slaveIds.size() / 2),
+          getShuffleMapperAttempts(shuffleId),
+          commitEpoch.incrementAndGet())
+        val res1 = requestCommitFilesWithRetry(worker.endpoint, commitFiles1)
+
+        val commitFiles = CommitFiles(
+          applicationId,
+          shuffleId,
+          masterIds.subList(masterIds.size() / 2, masterIds.size()),
+          slaveIds.subList(slaveIds.size() / 2, slaveIds.size()),
+          getShuffleMapperAttempts(shuffleId),
+          commitEpoch.incrementAndGet())
+        val res2 = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
+
+        
res1.committedMasterStorageInfos.putAll(res2.committedMasterStorageInfos)
+        res1.committedSlaveStorageInfos.putAll(res2.committedSlaveStorageInfos)
+        res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
+        CommitFilesResponse(
+          status = if (res1.status == StatusCode.SUCCESS) res2.status else 
res1.status,
+          (res1.committedMasterIds.asScala ++ 
res2.committedMasterIds.asScala).toList.asJava,
+          (res1.committedSlaveIds.asScala ++ 
res1.committedSlaveIds.asScala).toList.asJava,
+          (res1.failedMasterIds.asScala ++ 
res1.failedMasterIds.asScala).toList.asJava,
+          (res1.failedSlaveIds.asScala ++ 
res2.failedSlaveIds.asScala).toList.asJava,
+          res1.committedMasterStorageInfos,
+          res1.committedSlaveStorageInfos,
+          res1.committedMapIdBitMap,
+          res1.totalWritten + res2.totalWritten,
+          res1.fileCount + res2.fileCount)
+      }
+
+    shuffleCommittedInfo.synchronized {
+      // record committed partitionIds
+      res.committedMasterIds.asScala.foreach({
+        case commitMasterId =>
+          val partitionUniqueIdList = 
shuffleCommittedInfo.committedMasterIds.computeIfAbsent(
+            Utils.splitPartitionLocationUniqueId(commitMasterId)._1,
+            (k: Int) => new util.ArrayList[String]())
+          partitionUniqueIdList.add(commitMasterId)
+      })
+
+      res.committedSlaveIds.asScala.foreach({
+        case commitSlaveId =>
+          val partitionUniqueIdList = 
shuffleCommittedInfo.committedSlaveIds.computeIfAbsent(
+            Utils.splitPartitionLocationUniqueId(commitSlaveId)._1,
+            (k: Int) => new util.ArrayList[String]())
+          partitionUniqueIdList.add(commitSlaveId)
+      })
+
+      // record committed partitions storage hint and disk hint
+      
shuffleCommittedInfo.committedMasterStorageInfos.putAll(res.committedMasterStorageInfos)
+      
shuffleCommittedInfo.committedSlaveStorageInfos.putAll(res.committedSlaveStorageInfos)
+
+      // record failed partitions
+      shuffleCommittedInfo.failedMasterPartitionIds.putAll(
+        res.failedMasterIds.asScala.map((_, worker)).toMap.asJava)
+      shuffleCommittedInfo.failedSlavePartitionIds.putAll(
+        res.failedSlaveIds.asScala.map((_, worker)).toMap.asJava)
+
+      
shuffleCommittedInfo.committedMapIdBitmap.putAll(res.committedMapIdBitMap)
+
+      totalWritten.add(res.totalWritten)
+      fileCount.add(res.fileCount)
+      shuffleCommittedInfo.currentShuffleFileCount.add(res.fileCount)
+    }
+  }
+
+  def collectResult(
+      shuffleId: Int,
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      masterPartitionUniqueIds: util.Iterator[String],
+      slavePartitionUniqueIds: util.Iterator[String],
+      masterPartMap: ConcurrentHashMap[String, PartitionLocation],
+      slavePartMap: ConcurrentHashMap[String, PartitionLocation]): Unit = {
+    val committedPartitions = new util.HashMap[String, PartitionLocation]
+    masterPartitionUniqueIds.asScala.foreach { id =>
+      if (shuffleCommittedInfo.committedMasterStorageInfos.get(id) == null) {
+        logDebug(s"$appId-$shuffleId $id storage hint was not returned")
+      } else {
+        masterPartMap.get(id).setStorageInfo(
+          shuffleCommittedInfo.committedMasterStorageInfos.get(id))
+        
masterPartMap.get(id).setMapIdBitMap(shuffleCommittedInfo.committedMapIdBitmap.get(id))
+        committedPartitions.put(id, masterPartMap.get(id))
+      }
+    }
+
+    slavePartitionUniqueIds.asScala.foreach { id =>
+      val slavePartition = slavePartMap.get(id)
+      if (shuffleCommittedInfo.committedSlaveStorageInfos.get(id) == null) {
+        logDebug(s"$appId-$shuffleId $id storage hint was not returned")
+      } else {
+        
slavePartition.setStorageInfo(shuffleCommittedInfo.committedSlaveStorageInfos.get(id))
+        val masterPartition = committedPartitions.get(id)
+        if (masterPartition ne null) {
+          masterPartition.setPeer(slavePartition)
+          slavePartition.setPeer(masterPartition)
+        } else {
+          logInfo(s"Shuffle $shuffleId partition $id: master lost, " +
+            s"use slave $slavePartition.")
+          
slavePartition.setMapIdBitMap(shuffleCommittedInfo.committedMapIdBitmap.get(id))
+          committedPartitions.put(id, slavePartition)
+        }
+      }
+    }
+
+    committedPartitions.values().asScala.foreach { partition =>
+      val partitionLocations = 
reducerFileGroupsMap.get(shuffleId).computeIfAbsent(
+        partition.getId,
+        (k: Integer) => new util.HashSet[PartitionLocation]())
+      partitionLocations.add(partition)
+    }
+  }
+
+  private def requestCommitFilesWithRetry(
+      endpoint: RpcEndpointRef,
+      message: CommitFiles): CommitFilesResponse = {
+    val maxRetries = conf.requestCommitFilesMaxRetries
+    var retryTimes = 0
+    while (retryTimes < maxRetries) {
+      try {
+        if (testRetryCommitFiles && retryTimes < maxRetries - 1) {
+          endpoint.ask[CommitFilesResponse](message)
+          Thread.sleep(1000)
+          throw new Exception("Mock fail for CommitFiles")
+        } else {
+          return endpoint.askSync[CommitFilesResponse](message)
+        }
+      } catch {
+        case e: Throwable =>
+          retryTimes += 1
+          logError(
+            s"AskSync CommitFiles for ${message.shuffleId} failed (attempt 
$retryTimes/$maxRetries).",
+            e)
+      }
+    }
+
+    CommitFilesResponse(
+      StatusCode.FAILED,
+      List.empty.asJava,
+      List.empty.asJava,
+      message.masterIds,
+      message.slaveIds)
+  }
+
+  def checkDataLost(
+      shuffleId: Int,
+      masterPartitionUniqueIdMap: util.Map[String, WorkerInfo],
+      slavePartitionUniqueIdMap: util.Map[String, WorkerInfo]): Boolean = {
+    val shuffleKey = Utils.makeShuffleKey(appId, shuffleId)
+    if (!pushReplicateEnabled && masterPartitionUniqueIdMap.size() != 0) {
+      val msg =
+        masterPartitionUniqueIdMap.asScala.map {
+          case (partitionUniqueId, workerInfo) =>
+            s"Lost partition $partitionUniqueId in worker 
[${workerInfo.readableAddress()}]"
+        }.mkString("\n")
+      logError(
+        s"""
+           |For shuffle $shuffleKey partition data lost:
+           |$msg
+           |""".stripMargin)
+      true
+    } else {
+      val failedBothPartitionIdsToWorker = 
masterPartitionUniqueIdMap.asScala.flatMap {
+        case (partitionUniqueId, worker) =>
+          if (slavePartitionUniqueIdMap.asScala.contains(partitionUniqueId)) {
+            Some(partitionUniqueId -> (worker, 
slavePartitionUniqueIdMap.get(partitionUniqueId)))
+          } else {
+            None
+          }
+      }
+      if (failedBothPartitionIdsToWorker.nonEmpty) {
+        val msg = failedBothPartitionIdsToWorker.map {
+          case (partitionUniqueId, (masterWorker, slaveWorker)) =>
+            s"Lost partition $partitionUniqueId " +
+              s"in master worker [${masterWorker.readableAddress()}] and slave 
worker [$slaveWorker]"
+        }.mkString("\n")
+        logError(
+          s"""
+             |For shuffle $shuffleKey partition data lost:
+             |$msg
+             |""".stripMargin)
+        true
+      } else {
+        false
+      }
+    }
+  }
+
+  def commitMetrics(): (Long, Long) = (totalWritten.sumThenReset(), 
fileCount.sumThenReset())
+}
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
new file mode 100644
index 00000000..2b25080b
--- /dev/null
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client.commit
+
+import java.util
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
+import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, 
ShuffleFailedWorkers, ShuffleFileGroups}
+import org.apache.celeborn.client.ShuffleCommittedInfo
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.protocol.PartitionType
+// Can Remove this if celeborn don't support scala211 in future
+import org.apache.celeborn.common.util.FunctionConverter._
+import org.apache.celeborn.common.util.Utils
+
/**
 * This commit handler is for MapPartition ShuffleType, which means that a Map Partition contains all data produced
 * by an upstream MapTask, and data in a Map Partition may be consumed by multiple ReduceTasks. If the upstream MapTask
 * has multiple outputs, each will be a Map Partition.
 *
 * @see [[org.apache.celeborn.common.protocol.PartitionType.MAP]]
 */
class MapPartitionCommitHandler(
    appId: String,
    conf: CelebornConf,
    allocatedWorkers: ShuffleAllocatedWorkers,
    reducerFileGroupsMap: ShuffleFileGroups,
    committedPartitionInfo: CommittedPartitionInfo)
  extends CommitHandler(appId, conf, allocatedWorkers, reducerFileGroupsMap, committedPartitionInfo)
  with Logging {

  // shuffleId -> in processing partitionId set
  private val inProcessMapPartitionEndIds = new ConcurrentHashMap[Int, util.Set[Int]]()

  override def getPartitionType(): PartitionType = {
    PartitionType.MAP
  }

  // Map partitions commit individually; there is no stage-end phase to set.
  override def setStageEnd(shuffleId: Int): Unit = {
    throw new UnsupportedOperationException(
      "Failed when do setStageEnd Operation, MapPartition shuffleType don't " +
        "support set stage end")
  }

  override def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = {
    inProcessMapPartitionEndIds.containsKey(shuffleId) && inProcessMapPartitionEndIds.get(
      shuffleId).contains(partitionId)
  }

  override def tryFinalCommit(
      shuffleId: Int,
      recordWorkerFailure: ShuffleFailedWorkers => Unit): Boolean = {
    throw new UnsupportedOperationException(
      "Failed when do final Commit Operation, MapPartition shuffleType only " +
        "support final partition Commit")
  }

  /**
   * Commits every location of a single map partition once its producer task
   * ends, then releases the partition's resources on the involved workers.
   * Returns true when the data was committed without loss (trivially true when
   * no worker holds locations for this partition).
   */
  override def finalPartitionCommit(
      shuffleId: Int,
      partitionId: Int,
      recordWorkerFailure: ShuffleFailedWorkers => Unit): Boolean = {
    val inProcessingPartitionIds =
      inProcessMapPartitionEndIds.computeIfAbsent(shuffleId, (k: Int) => new util.HashSet[Int]())
    inProcessingPartitionIds.add(partitionId)

    // only workers that actually hold locations of this partition participate
    val partitionAllocatedWorkers = allocatedWorkers.get(shuffleId).asScala.filter(p =>
      p._2.containsPartition(shuffleId.toString, partitionId)).asJava

    var dataCommitSuccess = true
    if (!partitionAllocatedWorkers.isEmpty) {
      val result =
        handleFinalPartitionCommitFiles(
          shuffleId,
          partitionAllocatedWorkers,
          partitionId)
      dataCommitSuccess = result._1
      recordWorkerFailure(result._2)
    }

    // release resources and clear related info
    partitionAllocatedWorkers.asScala.foreach { case (_, partitionLocationInfo) =>
      partitionLocationInfo.removeRelatedPartitions(shuffleId.toString, partitionId)
    }

    inProcessingPartitionIds.remove(partitionId)
    dataCommitSuccess
  }

  override def getShuffleMapperAttempts(shuffleId: Int): Array[Int] = {
    // map partition now return empty mapper attempts array as map partition don't prevent other mapper commit file
    // even the same mapper id with another attemptId success in lifecycle manager.
    Array.empty
  }

  override def removeExpiredShuffle(shuffleId: Int): Unit = {
    inProcessMapPartitionEndIds.remove(shuffleId)
  }

  /**
   * Commits the partition's files on the given workers, waits for in-flight
   * commit requests, and checks for data loss.
   *
   * @return (success, failed workers) — success is true iff no data was lost.
   */
  private def handleFinalPartitionCommitFiles(
      shuffleId: Int,
      allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo],
      partitionId: Int): (Boolean, ShuffleFailedWorkers) = {
    val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
    // commit files
    val parallelCommitResult = parallelCommitFiles(shuffleId, allocatedWorkers, Some(partitionId))

    // check map partition inflight request complete
    waitInflightRequestComplete(shuffleCommittedInfo, partitionId)

    // check partition data lost
    val failedMasterPartitionUniqueIds =
      getPartitionIds(shuffleCommittedInfo.failedMasterPartitionIds, partitionId)
    val failedSlavePartitionUniqueIds =
      getPartitionIds(shuffleCommittedInfo.failedSlavePartitionIds, partitionId)
    val dataLost =
      checkDataLost(shuffleId, failedMasterPartitionUniqueIds, failedSlavePartitionUniqueIds)

    // collect partition result
    if (!dataLost) {
      collectResult(
        shuffleId,
        shuffleCommittedInfo,
        getPartitionUniqueIds(shuffleCommittedInfo.committedMasterIds, partitionId),
        getPartitionUniqueIds(shuffleCommittedInfo.committedSlaveIds, partitionId),
        parallelCommitResult.masterPartitionLocationMap,
        parallelCommitResult.slavePartitionLocationMap)
    }

    // BUGFIX: the first tuple element is consumed by finalPartitionCommit as
    // dataCommitSuccess; returning `dataLost` inverted the flag and reported
    // success exactly when data was lost. Success means NOT lost.
    (!dataLost, parallelCommitResult.commitFilesFailedWorkers)
  }

  // Busy-waits until no commit request for this partition is still in flight.
  private def waitInflightRequestComplete(
      shuffleCommittedInfo: ShuffleCommittedInfo,
      partitionId: Int): Unit = {
    if (shuffleCommittedInfo.partitionInFlightCommitRequestNum.containsKey(partitionId)) {
      while (shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(
          partitionId).get() > 0) {
        Thread.sleep(1000)
      }
    }
  }

  // Keep only the entries whose unique id belongs to `partitionId`.
  private def getPartitionIds(
      partitionIds: ConcurrentHashMap[String, WorkerInfo],
      partitionId: Int): util.Map[String, WorkerInfo] = {
    partitionIds.asScala.filter(p =>
      Utils.splitPartitionLocationUniqueId(p._1)._1 ==
        partitionId).asJava
  }

  private def getPartitionUniqueIds(
      ids: ConcurrentHashMap[Int, util.List[String]],
      partitionId: Int): util.Iterator[String] = {
    ids.getOrDefault(partitionId, Collections.emptyList[String]).iterator()
  }
}
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
new file mode 100644
index 00000000..34e6758e
--- /dev/null
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client.commit
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
+import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, 
ShuffleFailedWorkers, ShuffleFileGroups, ShuffleMapperAttempts}
+import org.apache.celeborn.client.ShuffleCommittedInfo
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
+import org.apache.celeborn.common.protocol.PartitionType
+
+/**
+ * This commit handler is for ReducePartition ShuffleType, which means that a 
Reduce Partition contains all data
+ * produced by all upstream MapTasks, and data in a Reduce Partition would 
only be consumed by one ReduceTask. If the
+ * ReduceTask has multiple inputs, each will be a ReducePartition
+ *
+ * @see [[org.apache.celeborn.common.protocol.PartitionType.REDUCE]]
+ */
+class ReducePartitionCommitHandler(
+    appId: String,
+    conf: CelebornConf,
+    allocatedWorkers: ShuffleAllocatedWorkers,
+    reducerFileGroupsMap: ShuffleFileGroups,
+    committedPartitionInfo: CommittedPartitionInfo,
+    shuffleMapperAttempts: ShuffleMapperAttempts)
+  extends CommitHandler(appId, conf, allocatedWorkers, reducerFileGroupsMap, 
committedPartitionInfo)
+  with Logging {
+
+  private val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
+  private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
+  private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
+
+  override def getPartitionType(): PartitionType = {
+    PartitionType.REDUCE
+  }
+
+  override def isStageEnd(shuffleId: Int): Boolean = {
+    stageEndShuffleSet.contains(shuffleId)
+  }
+
+  override def isStageEndOrInProcess(shuffleId: Int): Boolean = {
+    inProcessStageEndShuffleSet.contains(shuffleId) ||
+    stageEndShuffleSet.contains(shuffleId)
+  }
+
+  override def isStageDataLost(shuffleId: Int): Boolean = {
+    dataLostShuffleSet.contains(shuffleId)
+  }
+
+  override def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean 
= {
+    isStageEndOrInProcess(shuffleId)
+  }
+
+  override def setStageEnd(shuffleId: Int): Unit = {
+    stageEndShuffleSet.add(shuffleId)
+  }
+
+  def removeExpiredShuffle(shuffleId: Int): Unit = {
+    dataLostShuffleSet.remove(shuffleId)
+    stageEndShuffleSet.remove(shuffleId)
+    inProcessStageEndShuffleSet.remove(shuffleId)
+  }
+
+  override def tryFinalCommit(
+      shuffleId: Int,
+      recordWorkerFailure: ShuffleFailedWorkers => Unit): Boolean = {
+    if (this.isStageEnd(shuffleId)) {
+      logInfo(s"[handleStageEnd] Shuffle $shuffleId already ended!")
+      return false
+    } else {
+      inProcessStageEndShuffleSet.synchronized {
+        if (inProcessStageEndShuffleSet.contains(shuffleId)) {
+          logWarning(s"[handleStageEnd] Shuffle $shuffleId is in process!")
+          return false
+        } else {
+          inProcessStageEndShuffleSet.add(shuffleId)
+        }
+      }
+    }
+
+    // ask allLocations workers holding partitions to commit files
+    val shuffleAllocatedWorkers = allocatedWorkers.get(shuffleId)
+    val (dataLost, commitFailedWorkers) = handleFinalCommitFiles(shuffleId, 
shuffleAllocatedWorkers)
+    recordWorkerFailure(commitFailedWorkers)
+    // reply
+    if (!dataLost) {
+      logInfo(s"Succeed to handle stageEnd for $shuffleId.")
+      // record in stageEndShuffleSet
+      stageEndShuffleSet.add(shuffleId)
+    } else {
+      logError(s"Failed to handle stageEnd for $shuffleId, lost file!")
+      dataLostShuffleSet.add(shuffleId)
+      // record in stageEndShuffleSet
+      stageEndShuffleSet.add(shuffleId)
+    }
+    inProcessStageEndShuffleSet.remove(shuffleId)
+    true
+  }
+
+  private def handleFinalCommitFiles(
+      shuffleId: Int,
+      allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo])
+      : (Boolean, ShuffleFailedWorkers) = {
+    val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
+
+    // commit files
+    val parallelCommitResult = parallelCommitFiles(shuffleId, 
allocatedWorkers, None)
+
+    // check all inflight request complete
+    waitInflightRequestComplete(shuffleCommittedInfo)
+
+    // check data lost
+    val dataLost = checkDataLost(
+      shuffleId,
+      shuffleCommittedInfo.failedMasterPartitionIds,
+      shuffleCommittedInfo.failedSlavePartitionIds)
+
+    // collect result
+    if (!dataLost) {
+      collectResult(
+        shuffleId,
+        shuffleCommittedInfo,
+        getPartitionUniqueIds(shuffleCommittedInfo.committedMasterIds),
+        getPartitionUniqueIds(shuffleCommittedInfo.committedSlaveIds),
+        parallelCommitResult.masterPartitionLocationMap,
+        parallelCommitResult.slavePartitionLocationMap)
+    }
+
+    (dataLost, parallelCommitResult.commitFilesFailedWorkers)
+  }
+
+  override def finalPartitionCommit(
+      shuffleId: Int,
+      partitionId: Int,
+      recordWorkerFailure: ShuffleFailedWorkers => Unit): Boolean = {
+    throw new UnsupportedOperationException(
+      s"Failed when do final Partition Commit Operation, Reduce Partition " +
+        s"shuffleType only Support final commit for all partitions ")
+  }
+
+  override def getShuffleMapperAttempts(shuffleId: Int): Array[Int] = {
+    shuffleMapperAttempts.get(shuffleId)
+  }
+
+  private def waitInflightRequestComplete(shuffleCommittedInfo: 
ShuffleCommittedInfo): Unit = {
+    while (shuffleCommittedInfo.allInFlightCommitRequestNum.get() > 0) {
+      Thread.sleep(1000)
+    }
+  }
+
+  private def getPartitionUniqueIds(ids: ConcurrentHashMap[Int, 
util.List[String]])
+      : util.Iterator[String] = {
+    ids.asScala.flatMap(_._2.asScala).toIterator.asJava
+  }
+}
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
 
b/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
index 78550907..e23d2ece 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
@@ -46,16 +46,7 @@ class PartitionLocationInfo extends Logging {
     slavePartitionLocations.containsKey(shuffleKey)
   }
 
-  def containsRelatedShuffleOrPartition(shuffleKey: String, partitionIdOpt: 
Option[Int]): Boolean =
-    this.synchronized {
-      partitionIdOpt match {
-        case Some(partitionId) =>
-          containsPartition(shuffleKey, partitionId)
-        case None => containsShuffle(shuffleKey)
-      }
-    }
-
-  private def containsPartition(shuffleKey: String, partitionId: Int): Boolean 
= this
+  def containsPartition(shuffleKey: String, partitionId: Int): Boolean = this
     .synchronized {
       val contain = masterPartitionLocations.containsKey(
         shuffleKey) && 
masterPartitionLocations.get(shuffleKey).containsKey(partitionId)
@@ -237,21 +228,15 @@ class PartitionLocationInfo extends Logging {
     }
   }
 
-  def removeAllRelatedPartitions(
+  def removeRelatedPartitions(
       shuffleKey: String,
-      partitionIdOpt: Option[Int]): Unit = this
+      partitionId: Int): Unit = this
     .synchronized {
-      partitionIdOpt match {
-        case Some(partitionId) =>
-          if (masterPartitionLocations.containsKey(shuffleKey)) {
-            masterPartitionLocations.get(shuffleKey).remove(partitionId)
-          }
-          if (slavePartitionLocations.containsKey(shuffleKey)) {
-            slavePartitionLocations.get(shuffleKey).remove(partitionId)
-          }
-        case None =>
-          removeMasterPartitions(shuffleKey)
-          removeSlavePartitions(shuffleKey)
+      if (masterPartitionLocations.containsKey(shuffleKey)) {
+        masterPartitionLocations.get(shuffleKey).remove(partitionId)
+      }
+      if (slavePartitionLocations.containsKey(shuffleKey)) {
+        slavePartitionLocations.get(shuffleKey).remove(partitionId)
       }
     }
 

Reply via email to