This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new e2196e93 [CELEBORN-56] [ISSUE-945] handle map partition mapper end 
(#1003)
e2196e93 is described below

commit e2196e93830fdd90bdb3861d513942cb48723f0e
Author: Shuang <[email protected]>
AuthorDate: Wed Dec 7 21:09:02 2022 +0800

    [CELEBORN-56] [ISSUE-945] handle map partition mapper end (#1003)
---
 .../org/apache/celeborn/client/ShuffleClient.java  |  12 +-
 .../apache/celeborn/client/ShuffleClientImpl.java  |  46 ++-
 .../org/apache/celeborn/client/CommitManager.scala | 348 +++++++++++++++------
 .../apache/celeborn/client/LifecycleManager.scala  | 140 ++++++---
 .../apache/celeborn/client/DummyShuffleClient.java |  10 +
 common/src/main/proto/TransportMessages.proto      |   6 +-
 .../common/meta/PartitionLocationInfo.scala        |  88 +++++-
 .../common/protocol/message/ControlMessages.scala  |  36 ++-
 .../celeborn/common/util/FunctionConverter.scala   |  32 ++
 .../apache/celeborn/common/util/UtilsSuite.scala   |  46 +++
 .../celeborn/tests/client/ShuffleClientSuite.scala |  41 ++-
 .../apache/celeborn/tests/spark/HugeDataTest.scala |  12 +-
 12 files changed, 629 insertions(+), 188 deletions(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index b9fcb9e1..3b76333f 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -149,11 +149,21 @@ public abstract class ShuffleClient implements Cloneable {
   public abstract void pushMergedData(String applicationId, int shuffleId, int 
mapId, int attemptId)
       throws IOException;
 
-  // Report partition locations written by the completed map task
+  // Report partition locations written by the completed map task of 
ReducePartition Shuffle Type
   public abstract void mapperEnd(
       String applicationId, int shuffleId, int mapId, int attemptId, int 
numMappers)
       throws IOException;
 
+  // Report partition locations written by the completed map task of 
MapPartition Shuffle Type
+  public abstract void mapPartitionMapperEnd(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int numMappers,
+      int partitionId)
+      throws IOException;
+
   // Cleanup states of the map task
   public abstract void cleanup(String applicationId, int shuffleId, int mapId, 
int attemptId);
 
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 67a6ee4e..a3d0c6b1 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -113,10 +113,10 @@ public class ShuffleClientImpl extends ShuffleClient {
       };
 
   private static class ReduceFileGroups {
-    final PartitionLocation[][] partitionGroups;
+    final Map<Integer, Set<PartitionLocation>> partitionGroups;
     final int[] mapAttempts;
 
-    ReduceFileGroups(PartitionLocation[][] partitionGroups, int[] mapAttempts) 
{
+    ReduceFileGroups(Map<Integer, Set<PartitionLocation>> partitionGroups, 
int[] mapAttempts) {
       this.partitionGroups = partitionGroups;
       this.mapAttempts = mapAttempts;
     }
@@ -1028,6 +1028,29 @@ public class ShuffleClientImpl extends ShuffleClient {
   public void mapperEnd(
       String applicationId, int shuffleId, int mapId, int attemptId, int 
numMappers)
       throws IOException {
+    mapEndInternal(applicationId, shuffleId, mapId, attemptId, numMappers, -1);
+  }
+
+  @Override
+  public void mapPartitionMapperEnd(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int numMappers,
+      int partitionId)
+      throws IOException {
+    mapEndInternal(applicationId, shuffleId, mapId, attemptId, numMappers, 
partitionId);
+  }
+
+  private void mapEndInternal(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int numMappers,
+      Integer partitionId)
+      throws IOException {
     final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
     PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new 
PushState(conf));
 
@@ -1036,7 +1059,7 @@ public class ShuffleClientImpl extends ShuffleClient {
 
       MapperEndResponse response =
           driverRssMetaService.askSync(
-              new MapperEnd(applicationId, shuffleId, mapId, attemptId, 
numMappers),
+              new MapperEnd(applicationId, shuffleId, mapId, attemptId, 
numMappers, partitionId),
               ClassTag$.MODULE$.apply(MapperEndResponse.class));
       if (response.status() != StatusCode.SUCCESS) {
         throw new IOException("MapperEnd failed! StatusCode: " + 
response.status());
@@ -1119,9 +1142,10 @@ public class ShuffleClientImpl extends ShuffleClient {
 
                 if (response.status() == StatusCode.SUCCESS) {
                   logger.info(
-                      "Shuffle {} request reducer file group success using 
time:{} ms",
+                      "Shuffle {} request reducer file group success using 
time:{} ms, result partition ids: {}",
                       shuffleId,
-                      (System.nanoTime() - getReducerFileGroupStartTime) / 
1000_000);
+                      (System.nanoTime() - getReducerFileGroupStartTime) / 
1000_000,
+                      response.fileGroup().keySet());
                   return new ReduceFileGroups(response.fileGroup(), 
response.attempts());
                 } else if (response.status() == StatusCode.STAGE_END_TIME_OUT) 
{
                   logger.warn(
@@ -1147,15 +1171,16 @@ public class ShuffleClientImpl extends ShuffleClient {
       String msg = "Shuffle data lost for shuffle " + shuffleId + " reduce " + 
partitionId + "!";
       logger.error(msg);
       throw new IOException(msg);
-    } else if (fileGroups.partitionGroups.length == 0) {
-      logger.warn("Shuffle data is empty for shuffle {} reduce {}.", 
shuffleId, partitionId);
+    } else if (fileGroups.partitionGroups.size() == 0
+        || !fileGroups.partitionGroups.containsKey(partitionId)) {
+      logger.warn("Shuffle data is empty for shuffle {} partitionId {}.", 
shuffleId, partitionId);
       return RssInputStream.empty();
     } else {
       return RssInputStream.create(
           conf,
           dataClientFactory,
           shuffleKey,
-          fileGroups.partitionGroups[partitionId],
+          fileGroups.partitionGroups.get(partitionId).toArray(new 
PartitionLocation[0]),
           fileGroups.mapAttempts,
           attemptNumber,
           startMapIndex,
@@ -1163,6 +1188,11 @@ public class ShuffleClientImpl extends ShuffleClient {
     }
   }
 
+  @VisibleForTesting
+  public Map<Integer, ReduceFileGroups> getReduceFileGroupsMap() {
+    return reduceFileGroupsMap;
+  }
+
   @Override
   public void shutDown() {
     if (null != rpcEnv) {
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 9bd7aeb7..4bb2ffd9 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -29,11 +29,12 @@ import org.roaringbitmap.RoaringBitmap
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
-import org.apache.celeborn.common.protocol.{PartitionLocation, StorageInfo}
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType, 
StorageInfo}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, 
CommitFilesResponse}
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc.RpcEndpointRef
 import org.apache.celeborn.common.util.{ThreadUtils, Utils}
+import org.apache.celeborn.common.util.FunctionConverter._
 
 case class CommitPartitionRequest(
     applicationId: String,
@@ -41,8 +42,8 @@ case class CommitPartitionRequest(
     partition: PartitionLocation)
 
 case class ShuffleCommittedInfo(
-    committedMasterIds: util.List[String],
-    committedSlaveIds: util.List[String],
+    committedMasterIds: ConcurrentHashMap[Int, util.List[String]],
+    committedSlaveIds: ConcurrentHashMap[Int, util.List[String]],
     failedMasterPartitionIds: ConcurrentHashMap[String, WorkerInfo],
     failedSlavePartitionIds: ConcurrentHashMap[String, WorkerInfo],
     committedMasterStorageInfos: ConcurrentHashMap[String, StorageInfo],
@@ -51,7 +52,8 @@ case class ShuffleCommittedInfo(
     currentShuffleFileCount: LongAdder,
     commitPartitionRequests: util.Set[CommitPartitionRequest],
     handledCommitPartitionRequests: util.Set[PartitionLocation],
-    inFlightCommitRequest: AtomicInteger)
+    allInFlightCommitRequestNum: AtomicInteger,
+    partitionInFlightCommitRequestNum: ConcurrentHashMap[Int, AtomicInteger])
 
 class CommitManager(appId: String, val conf: CelebornConf, lifecycleManager: 
LifecycleManager)
   extends Logging {
@@ -60,7 +62,8 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
   val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
-
+  // shuffleId -> in processing partitionId set
+  private val inProcessMapPartitionEndIds = new ConcurrentHashMap[Int, 
util.Set[Int]]()
   private val pushReplicateEnabled = conf.pushReplicateEnabled
 
   private val batchHandleCommitPartitionEnabled = 
conf.batchHandleCommitPartitionEnabled
@@ -92,6 +95,67 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
             committedPartitionInfo.asScala.foreach { case (shuffleId, 
shuffleCommittedInfo) =>
               batchHandleCommitPartitionExecutors.submit {
                 new Runnable {
+                  val partitionType = 
lifecycleManager.getPartitionType(shuffleId)
+                  def isPartitionInProcess(partitionId: Int): Boolean = {
+                    if (inProcessMapPartitionEndIds.containsKey(shuffleId) &&
+                      
inProcessMapPartitionEndIds.get(shuffleId).contains(partitionId)) {
+                      true
+                    } else {
+                      false
+                    }
+                  }
+
+                  def incrementInflightNum(workerToRequests: Map[
+                    WorkerInfo,
+                    collection.Set[PartitionLocation]]): Unit = {
+                    if (partitionType == PartitionType.MAP) {
+                      workerToRequests.foreach {
+                        case (_, partitions) =>
+                          partitions.groupBy(_.getId).foreach { case (id, _) =>
+                            val atomicInteger = shuffleCommittedInfo
+                              .partitionInFlightCommitRequestNum
+                              .computeIfAbsent(id, (k: Int) => new 
AtomicInteger(0))
+                            atomicInteger.incrementAndGet()
+                          }
+                      }
+                    }
+                    shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(
+                      workerToRequests.size)
+                  }
+
+                  def decrementInflightNum(
+                      workerToRequests: Map[WorkerInfo, 
collection.Set[PartitionLocation]])
+                      : Unit = {
+                    if (partitionType == PartitionType.MAP) {
+                      workerToRequests.foreach {
+                        case (_, partitions) =>
+                          partitions.groupBy(_.getId).foreach { case (id, _) =>
+                            
shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(
+                              id).decrementAndGet()
+                          }
+                      }
+                    }
+                    shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(
+                      -workerToRequests.size)
+                  }
+
+                  def getUnCommitPartitionRequests(
+                      commitPartitionRequests: 
util.Set[CommitPartitionRequest])
+                      : scala.collection.mutable.Set[CommitPartitionRequest] = 
{
+                    if (partitionType == PartitionType.MAP) {
+                      commitPartitionRequests.asScala.filterNot { request =>
+                        shuffleCommittedInfo.handledCommitPartitionRequests
+                          .contains(request.partition) && isPartitionInProcess(
+                          request.partition.getId)
+                      }
+                    } else {
+                      commitPartitionRequests.asScala.filterNot { request =>
+                        shuffleCommittedInfo.handledCommitPartitionRequests
+                          .contains(request.partition)
+                      }
+                    }
+                  }
+
                   override def run(): Unit = {
                     val workerToRequests = shuffleCommittedInfo.synchronized {
                       // When running to here, if handleStageEnd got lock 
first and commitFiles,
@@ -105,12 +169,8 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
                         shuffleCommittedInfo.commitPartitionRequests.clear()
                         Map.empty[WorkerInfo, Set[PartitionLocation]]
                       } else {
-                        val batch = new util.HashSet[CommitPartitionRequest]()
-                        
batch.addAll(shuffleCommittedInfo.commitPartitionRequests)
-                        val currentBatch = batch.asScala.filterNot { request =>
-                          shuffleCommittedInfo.handledCommitPartitionRequests
-                            .contains(request.partition)
-                        }
+                        val currentBatch =
+                          
getUnCommitPartitionRequests(shuffleCommittedInfo.commitPartitionRequests)
                         shuffleCommittedInfo.commitPartitionRequests.clear()
                         currentBatch.foreach { commitPartitionRequest =>
                           shuffleCommittedInfo.handledCommitPartitionRequests
@@ -131,8 +191,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
                               Seq(request.partition)
                             }
                           }.groupBy(_.getWorker)
-                          shuffleCommittedInfo.inFlightCommitRequest.addAndGet(
-                            workerToRequests.size)
+                          incrementInflightNum(workerToRequests)
                           workerToRequests
                         } else {
                           Map.empty[WorkerInfo, Set[PartitionLocation]]
@@ -180,7 +239,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
                         }
                         
lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
                       } finally {
-                        
shuffleCommittedInfo.inFlightCommitRequest.addAndGet(-workerToRequests.size)
+                        decrementInflightNum(workerToRequests)
                       }
                     }
                   }
@@ -204,8 +263,8 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
     committedPartitionInfo.put(
       shuffleId,
       ShuffleCommittedInfo(
-        new util.ArrayList[String](),
-        new util.ArrayList[String](),
+        new ConcurrentHashMap[Int, util.List[String]](),
+        new ConcurrentHashMap[Int, util.List[String]](),
         new ConcurrentHashMap[String, WorkerInfo](),
         new ConcurrentHashMap[String, WorkerInfo](),
         new ConcurrentHashMap[String, StorageInfo](),
@@ -214,11 +273,14 @@ class CommitManager(appId: String, val conf: 
CelebornConf, lifecycleManager: Lif
         new LongAdder,
         new util.HashSet[CommitPartitionRequest](),
         new util.HashSet[PartitionLocation](),
-        new AtomicInteger()))
+        new AtomicInteger(),
+        new ConcurrentHashMap[Int, AtomicInteger]()))
   }
 
   def removeExpiredShuffle(shuffleId: Int): Unit = {
     committedPartitionInfo.remove(shuffleId)
+    inProcessStageEndShuffleSet.remove(shuffleId)
+    inProcessMapPartitionEndIds.remove(shuffleId)
     dataLostShuffleSet.remove(shuffleId)
     stageEndShuffleSet.remove(shuffleId)
   }
@@ -304,8 +366,21 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
 
     shuffleCommittedInfo.synchronized {
       // record committed partitionIds
-      shuffleCommittedInfo.committedMasterIds.addAll(res.committedMasterIds)
-      shuffleCommittedInfo.committedSlaveIds.addAll(res.committedSlaveIds)
+      res.committedMasterIds.asScala.foreach({
+        case commitMasterId =>
+          val partitionUniqueIdList = 
shuffleCommittedInfo.committedMasterIds.computeIfAbsent(
+            Utils.splitPartitionLocationUniqueId(commitMasterId)._1,
+            (k: Int) => new util.ArrayList[String]())
+          partitionUniqueIdList.add(commitMasterId)
+      })
+
+      res.committedSlaveIds.asScala.foreach({
+        case commitSlaveId =>
+          val partitionUniqueIdList = 
shuffleCommittedInfo.committedSlaveIds.computeIfAbsent(
+            Utils.splitPartitionLocationUniqueId(commitSlaveId)._1,
+            (k: Int) => new util.ArrayList[String]())
+          partitionUniqueIdList.add(commitSlaveId)
+      })
 
       // record committed partitions storage hint and disk hint
       
shuffleCommittedInfo.committedMasterStorageInfos.putAll(res.committedMasterStorageInfos)
@@ -359,7 +434,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
   def finalCommit(
       applicationId: String,
       shuffleId: Int,
-      fileGroups: Array[Array[PartitionLocation]]): Unit = {
+      fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]]): 
Unit = {
     if (stageEndShuffleSet.contains(shuffleId)) {
       logInfo(s"[handleStageEnd] Shuffle $shuffleId already ended!")
       return
@@ -372,10 +447,31 @@ class CommitManager(appId: String, val conf: 
CelebornConf, lifecycleManager: Lif
       inProcessStageEndShuffleSet.add(shuffleId)
     }
     // ask allLocations workers holding partitions to commit files
+    val allocatedWorkers = 
lifecycleManager.shuffleAllocatedWorkers.get(shuffleId)
+    val dataLost = handleCommitFiles(applicationId, shuffleId, 
allocatedWorkers, None, fileGroups)
+
+    // reply
+    if (!dataLost) {
+      logInfo(s"Succeed to handle stageEnd for $shuffleId.")
+      // record in stageEndShuffleSet
+      stageEndShuffleSet.add(shuffleId)
+    } else {
+      logError(s"Failed to handle stageEnd for $shuffleId, lost file!")
+      dataLostShuffleSet.add(shuffleId)
+      // record in stageEndShuffleSet
+      stageEndShuffleSet.add(shuffleId)
+    }
+    inProcessStageEndShuffleSet.remove(shuffleId)
+  }
+
+  private def handleCommitFiles(
+      applicationId: String,
+      shuffleId: Int,
+      allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo],
+      partitionIdOpt: Option[Int] = None,
+      fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]]): 
Boolean = {
     val masterPartMap = new ConcurrentHashMap[String, PartitionLocation]
     val slavePartMap = new ConcurrentHashMap[String, PartitionLocation]
-
-    val allocatedWorkers = 
lifecycleManager.shuffleAllocatedWorkers.get(shuffleId)
     val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
     val commitFilesFailedWorkers = new ConcurrentHashMap[WorkerInfo, 
(StatusCode, Long)]()
     val commitFileStartTime = System.nanoTime()
@@ -386,8 +482,9 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
       "CommitFiles",
       parallelism) { case (worker, partitionLocationInfo) =>
       if (partitionLocationInfo.containsShuffle(shuffleId.toString)) {
-        val masterParts = 
partitionLocationInfo.getAllMasterLocations(shuffleId.toString)
-        val slaveParts = 
partitionLocationInfo.getAllSlaveLocations(shuffleId.toString)
+        val masterParts =
+          partitionLocationInfo.getMasterLocations(shuffleId.toString, 
partitionIdOpt)
+        val slaveParts = 
partitionLocationInfo.getSlaveLocations(shuffleId.toString, partitionIdOpt)
         masterParts.asScala.foreach { p =>
           val partition = new PartitionLocation(p)
           partition.setFetchPort(worker.fetchPort)
@@ -421,59 +518,16 @@ class CommitManager(appId: String, val conf: 
CelebornConf, lifecycleManager: Lif
           commitFilesFailedWorkers)
       }
     }
+    lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
+    // check all inflight request complete, for map partition, it's for single 
partitionId
+    waitInflightRequestComplete(shuffleId, shuffleCommittedInfo, 
partitionIdOpt)
 
-    def hasCommitFailedIds: Boolean = {
-      val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
-      if (!pushReplicateEnabled && 
shuffleCommittedInfo.failedMasterPartitionIds.size() != 0) {
-        val msg =
-          shuffleCommittedInfo.failedMasterPartitionIds.asScala.map {
-            case (partitionId, workerInfo) =>
-              s"Lost partition $partitionId in worker 
[${workerInfo.readableAddress()}]"
-          }.mkString("\n")
-        logError(
-          s"""
-             |For shuffle $shuffleKey partition data lost:
-             |$msg
-             |""".stripMargin)
-        true
-      } else {
-        val failedBothPartitionIdsToWorker =
-          shuffleCommittedInfo.failedMasterPartitionIds.asScala.flatMap {
-            case (partitionId, worker) =>
-              if 
(shuffleCommittedInfo.failedSlavePartitionIds.contains(partitionId)) {
-                Some(partitionId -> (worker, 
shuffleCommittedInfo.failedSlavePartitionIds.get(
-                  partitionId)))
-              } else {
-                None
-              }
-          }
-        if (failedBothPartitionIdsToWorker.nonEmpty) {
-          val msg = failedBothPartitionIdsToWorker.map {
-            case (partitionId, (masterWorker, slaveWorker)) =>
-              s"Lost partition $partitionId " +
-                s"in master worker [${masterWorker.readableAddress()}] and 
slave worker [$slaveWorker]"
-          }.mkString("\n")
-          logError(
-            s"""
-               |For shuffle $shuffleKey partition data lost:
-               |$msg
-               |""".stripMargin)
-          true
-        } else {
-          false
-        }
-      }
-    }
-
-    while (shuffleCommittedInfo.inFlightCommitRequest.get() > 0) {
-      Thread.sleep(1000)
-    }
-
-    val dataLost = hasCommitFailedIds
+    // check all data lost or not, for map partition, it's for single 
partitionId
+    val dataLost = checkDataLost(applicationId, shuffleId, partitionIdOpt)
 
     if (!dataLost) {
       val committedPartitions = new util.HashMap[String, PartitionLocation]
-      shuffleCommittedInfo.committedMasterIds.asScala.foreach { id =>
+      getPartitionUniqueIds(shuffleCommittedInfo.committedMasterIds, 
partitionIdOpt).foreach { id =>
         if (shuffleCommittedInfo.committedMasterStorageInfos.get(id) == null) {
           logDebug(s"$applicationId-$shuffleId $id storage hint was not 
returned")
         } else {
@@ -484,7 +538,7 @@ class CommitManager(appId: String, val conf: CelebornConf, 
lifecycleManager: Lif
         }
       }
 
-      shuffleCommittedInfo.committedSlaveIds.asScala.foreach { id =>
+      getPartitionUniqueIds(shuffleCommittedInfo.committedSlaveIds, 
partitionIdOpt).foreach { id =>
         val slavePartition = slavePartMap.get(id)
         if (shuffleCommittedInfo.committedSlaveStorageInfos.get(id) == null) {
           logDebug(s"$applicationId-$shuffleId $id storage hint was not 
returned")
@@ -503,14 +557,11 @@ class CommitManager(appId: String, val conf: 
CelebornConf, lifecycleManager: Lif
         }
       }
 
-      val sets = Array.fill(fileGroups.length)(new 
util.HashSet[PartitionLocation]())
       committedPartitions.values().asScala.foreach { partition =>
-        sets(partition.getId).add(partition)
-      }
-      var i = 0
-      while (i < fileGroups.length) {
-        fileGroups(i) = sets(i).toArray(new Array[PartitionLocation](0))
-        i += 1
+        val partitionLocations = fileGroups.computeIfAbsent(
+          partition.getId,
+          (k: Integer) => new util.HashSet[PartitionLocation]())
+        partitionLocations.add(partition)
       }
 
       logInfo(s"Shuffle $shuffleId " +
@@ -518,25 +569,130 @@ class CommitManager(appId: String, val conf: 
CelebornConf, lifecycleManager: Lif
         s"using ${(System.nanoTime() - commitFileStartTime) / 1000000} ms")
     }
 
-    // reply
-    if (!dataLost) {
-      logInfo(s"Succeed to handle stageEnd for $shuffleId.")
-      // record in stageEndShuffleSet
-      stageEndShuffleSet.add(shuffleId)
+    dataLost
+  }
+
+  private def getPartitionIds(
+      partitionIds: ConcurrentHashMap[String, WorkerInfo],
+      partitionIdOpt: Option[Int]): util.Map[String, WorkerInfo] = {
+    partitionIdOpt match {
+      case Some(partitionId) => partitionIds.asScala.filter(p =>
+          Utils.splitPartitionLocationUniqueId(p._1)._1 ==
+            partitionId).asJava
+      case None => partitionIds
+    }
+  }
+
+  private def waitInflightRequestComplete(
+      shuffleId: Int,
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      partitionIdOpt: Option[Int]): Unit = {
+    lifecycleManager.getPartitionType(shuffleId) match {
+      case PartitionType.REDUCE =>
+        while (shuffleCommittedInfo.allInFlightCommitRequestNum.get() > 0) {
+          Thread.sleep(1000)
+        }
+      case PartitionType.MAP => partitionIdOpt match {
+          case Some(partitionId) =>
+            if 
(shuffleCommittedInfo.partitionInFlightCommitRequestNum.containsKey(partitionId))
 {
+              while 
(shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(
+                  partitionId).get() > 0) {
+                Thread.sleep(1000)
+              }
+            }
+        }
+    }
+  }
+
+  private def checkDataLost(
+      applicationId: String,
+      shuffleId: Int,
+      partitionIdOpt: Option[Int]): Boolean = {
+    val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
+    val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
+    val masterPartitionUniqueIdMap =
+      getPartitionIds(shuffleCommittedInfo.failedMasterPartitionIds, 
partitionIdOpt)
+    if (!pushReplicateEnabled && masterPartitionUniqueIdMap.size() != 0) {
+      val msg =
+        masterPartitionUniqueIdMap.asScala.map {
+          case (partitionUniqueId, workerInfo) =>
+            s"Lost partition $partitionUniqueId in worker 
[${workerInfo.readableAddress()}]"
+        }.mkString("\n")
+      logError(
+        s"""
+           |For shuffle $shuffleKey partition data lost:
+           |$msg
+           |""".stripMargin)
+      true
     } else {
-      logError(s"Failed to handle stageEnd for $shuffleId, lost file!")
-      dataLostShuffleSet.add(shuffleId)
-      // record in stageEndShuffleSet
-      stageEndShuffleSet.add(shuffleId)
+      val slavePartitionUniqueIdMap =
+        getPartitionIds(shuffleCommittedInfo.failedSlavePartitionIds, 
partitionIdOpt)
+      val failedBothPartitionIdsToWorker = 
masterPartitionUniqueIdMap.asScala.flatMap {
+        case (partitionUniqueId, worker) =>
+          if (slavePartitionUniqueIdMap.asScala.contains(partitionUniqueId)) {
+            Some(partitionUniqueId -> (worker, 
slavePartitionUniqueIdMap.get(partitionUniqueId)))
+          } else {
+            None
+          }
+      }
+      if (failedBothPartitionIdsToWorker.nonEmpty) {
+        val msg = failedBothPartitionIdsToWorker.map {
+          case (partitionUniqueId, (masterWorker, slaveWorker)) =>
+            s"Lost partition $partitionUniqueId " +
+              s"in master worker [${masterWorker.readableAddress()}] and slave 
worker [$slaveWorker]"
+        }.mkString("\n")
+        logError(
+          s"""
+             |For shuffle $shuffleKey partition data lost:
+             |$msg
+             |""".stripMargin)
+        true
+      } else {
+        false
+      }
     }
-    inProcessStageEndShuffleSet.remove(shuffleId)
-    lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
   }
 
-  def removeExpiredShuffle(shuffleId: String): Unit = {
-    stageEndShuffleSet.remove(shuffleId)
-    dataLostShuffleSet.remove(shuffleId)
-    committedPartitionInfo.remove(shuffleId)
+  private def getPartitionUniqueIds(
+      ids: ConcurrentHashMap[Int, util.List[String]],
+      partitionIdOpt: Option[Int]): Iterable[String] = {
+    partitionIdOpt match {
+      case Some(partitionId) => ids.asScala.filter(_._1 == 
partitionId).flatMap(_._2.asScala)
+      case None => ids.asScala.flatMap(_._2.asScala)
+    }
+  }
+
+  def finalPartitionCommit(
+      applicationId: String,
+      shuffleId: Int,
+      fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]],
+      partitionId: Int): Boolean = {
+    val inProcessingPartitionIds =
+      inProcessMapPartitionEndIds.computeIfAbsent(shuffleId, (k: Int) => new 
util.HashSet[Int]())
+    inProcessingPartitionIds.add(partitionId)
+
+    val allocatedWorkers =
+      lifecycleManager.shuffleAllocatedWorkers.get(shuffleId).asScala.filter(p 
=>
+        p._2.containsRelatedShuffleOrPartition(shuffleId.toString, 
Option(partitionId))).asJava
+
+    var dataCommitSuccess = true
+    if (!allocatedWorkers.isEmpty) {
+      dataCommitSuccess =
+        !handleCommitFiles(
+          applicationId,
+          shuffleId,
+          allocatedWorkers,
+          Option(partitionId),
+          fileGroups)
+    }
+
+    // release resources and clear worker info
+    allocatedWorkers.asScala.foreach { case (_, partitionLocationInfo) =>
+      partitionLocationInfo.removeAllRelatedPartitions(shuffleId.toString, 
Option(partitionId))
+    }
+    inProcessingPartitionIds.remove(partitionId)
+
+    dataCommitSuccess
   }
 
   def commitMetrics(): (Long, Long) = (totalWritten.sumThenReset(), 
fileCount.sumThenReset())
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 0ecce2ce..287ecd30 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -20,8 +20,7 @@ package org.apache.celeborn.client
 import java.nio.ByteBuffer
 import java.util
 import java.util.{function, List => JList}
-import java.util.concurrent.{Callable, ConcurrentHashMap, 
ScheduledExecutorService, ScheduledFuture, TimeUnit}
-import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder}
+import java.util.concurrent.{Callable, ConcurrentHashMap, ScheduledFuture, 
TimeUnit}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -29,7 +28,6 @@ import scala.util.Random
 
 import com.google.common.annotations.VisibleForTesting
 import com.google.common.cache.{Cache, CacheBuilder}
-import org.roaringbitmap.RoaringBitmap
 
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.haclient.RssHARetryClient
@@ -43,6 +41,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc._
 import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, 
RemoteNettyRpcCallContext}
 import org.apache.celeborn.common.util.{PbSerDeUtils, ThreadUtils, Utils}
+import org.apache.celeborn.common.util.FunctionConverter._
 
 class LifecycleManager(appId: String, val conf: CelebornConf) extends 
RpcEndpoint with Logging {
 
@@ -66,7 +65,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) 
extends RpcEndpoin
   val registeredShuffle = ConcurrentHashMap.newKeySet[Int]()
   val shuffleMapperAttempts = new ConcurrentHashMap[Int, Array[Int]]()
   private val reducerFileGroupsMap =
-    new ConcurrentHashMap[Int, Array[Array[PartitionLocation]]]()
+    new ConcurrentHashMap[Int, ConcurrentHashMap[Integer, 
util.Set[PartitionLocation]]]()
   private val shuffleTaskInfo = new ShuffleTaskInfo()
   // maintain each shuffle's map relation of WorkerInfo and partition location
   val shuffleAllocatedWorkers =
@@ -310,10 +309,16 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
         epoch,
         oldPartition)
 
-    case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers) =>
-      logTrace(s"Received MapperEnd request, " +
-        s"${Utils.makeMapKey(applicationId, shuffleId, mapId, attemptId)}.")
-      handleMapperEnd(context, applicationId, shuffleId, mapId, attemptId, 
numMappers)
+    case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers, 
partitionId) =>
+      logTrace(s"Received MapperEnd TaskEnd request, " +
+        s"${Utils.makeMapKey(applicationId, shuffleId, mapId, attemptId)}")
+      val partitionType = getPartitionType(shuffleId)
+      partitionType match {
+        case PartitionType.REDUCE =>
+          handleMapperEnd(context, applicationId, shuffleId, mapId, attemptId, 
numMappers)
+        case PartitionType.MAP =>
+          handleMapPartitionEnd(context, applicationId, shuffleId, mapId, 
attemptId, partitionId)
+      }
 
     case GetReducerFileGroup(applicationId: String, shuffleId: Int) =>
       logDebug(s"Received GetShuffleFileGroup request," +
@@ -483,8 +488,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
           }
         }
       }
-
-      reducerFileGroupsMap.put(shuffleId, new 
Array[Array[PartitionLocation]](numReducers))
+      reducerFileGroupsMap.put(shuffleId, new ConcurrentHashMap())
 
       // Fifth, reply the allocated partition location to ShuffleClient.
       logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
@@ -551,13 +555,15 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       return
     }
 
-    // If shuffle registered and corresponding map finished, reply MapEnd and 
return.
-    if (shuffleMapperAttempts.containsKey(shuffleId)
-      && shuffleMapperAttempts.get(shuffleId)(mapId) != -1) {
-      logWarning(s"[handleRevive] Mapper ended, mapId $mapId, current 
attemptId $attemptId, " +
-        s"ended attemptId ${shuffleMapperAttempts.get(shuffleId)(mapId)}, 
shuffleId $shuffleId.")
-      context.reply(ChangeLocationResponse(StatusCode.MAP_ENDED, None))
-      return
+    // If shuffle registered and corresponding map finished, reply MapEnd and 
return. Only for reduce partition type
+    if (getPartitionType(shuffleId) == PartitionType.REDUCE) {
+      if (shuffleMapperAttempts.containsKey(shuffleId) && 
shuffleMapperAttempts.get(shuffleId)(
+          mapId) != -1) {
+        logWarning(s"[handleRevive] Mapper ended, mapId $mapId, current 
attemptId $attemptId, " +
+          s"ended attemptId ${shuffleMapperAttempts.get(shuffleId)(mapId)}, 
shuffleId $shuffleId.")
+        context.reply(ChangeLocationResponse(StatusCode.MAP_ENDED, None))
+        return
+      }
     }
 
     logWarning(s"Do Revive for shuffle ${Utils.makeShuffleKey(applicationId, 
shuffleId)}, " +
@@ -621,28 +627,39 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       shuffleId: Int): Unit = {
     var timeout = stageEndTimeout
     val delta = 100
-    while (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
-      Thread.sleep(delta)
-      if (timeout <= 0) {
-        logError(s"[handleGetReducerFileGroup] Wait for handleStageEnd 
Timeout! $shuffleId.")
-        context.reply(
-          GetReducerFileGroupResponse(StatusCode.STAGE_END_TIME_OUT, 
Array.empty, Array.empty))
-        return
+    // Reduce partitions need to wait for stage end, while map partitions commit 
every partition synchronously.
+    if (getPartitionType(shuffleId) == PartitionType.REDUCE) {
+      while (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
+        Thread.sleep(delta)
+        if (timeout <= 0) {
+          logError(s"[handleGetReducerFileGroup] Wait for handleStageEnd 
Timeout! $shuffleId.")
+          context.reply(
+            GetReducerFileGroupResponse(
+              StatusCode.STAGE_END_TIME_OUT,
+              new ConcurrentHashMap(),
+              Array.empty))
+          return
+        }
+        timeout = timeout - delta
       }
-      timeout = timeout - delta
+      logDebug("[handleGetReducerFileGroup] Wait for handleStageEnd complete 
cost" +
+        s" ${stageEndTimeout - timeout}ms")
     }
-    logDebug("[handleGetReducerFileGroup] Wait for handleStageEnd complete 
cost" +
-      s" ${stageEndTimeout - timeout}ms")
 
     if (commitManager.dataLostShuffleSet.contains(shuffleId)) {
       context.reply(
-        GetReducerFileGroupResponse(StatusCode.SHUFFLE_DATA_LOST, Array.empty, 
Array.empty))
+        GetReducerFileGroupResponse(
+          StatusCode.SHUFFLE_DATA_LOST,
+          new ConcurrentHashMap(),
+          Array.empty))
     } else {
       if (context.isInstanceOf[LocalNettyRpcCallContext]) {
         // This branch is for the UTs
         context.reply(GetReducerFileGroupResponse(
           StatusCode.SUCCESS,
-          reducerFileGroupsMap.getOrDefault(shuffleId, Array.empty),
+          reducerFileGroupsMap.getOrDefault(
+            shuffleId,
+            new ConcurrentHashMap()),
           shuffleMapperAttempts.getOrDefault(shuffleId, Array.empty)))
       } else {
         val cachedMsg = getReducerFileGroupRpcCache.get(
@@ -651,7 +668,9 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
             override def call(): ByteBuffer = {
               val returnedMsg = GetReducerFileGroupResponse(
                 StatusCode.SUCCESS,
-                reducerFileGroupsMap.getOrDefault(shuffleId, Array.empty),
+                reducerFileGroupsMap.getOrDefault(
+                  shuffleId,
+                  new ConcurrentHashMap()),
                 shuffleMapperAttempts.getOrDefault(shuffleId, Array.empty))
               
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
             }
@@ -684,24 +703,55 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
       ReleaseSlots(applicationId, shuffleId, List.empty.asJava, 
List.empty.asJava))
   }
 
+  private def handleMapPartitionEnd(
+      context: RpcCallContext,
+      applicationId: String,
+      shuffleId: Int,
+      mapId: Int,
+      attemptId: Int,
+      partitionId: Int): Unit = {
+    def reply(result: Boolean): Unit = {
+      val message =
+        s"to handle MapPartitionEnd for ${Utils.makeMapKey(appId, shuffleId, 
mapId, attemptId)}, " +
+          s"$partitionId."
+      result match {
+        case true => // if already committed by another try
+          logDebug(s"Succeed $message")
+          context.reply(MapperEndResponse(StatusCode.SUCCESS))
+        case false =>
+          logError(s"Failed $message")
+          context.reply(MapperEndResponse(StatusCode.SHUFFLE_DATA_LOST))
+      }
+    }
+
+    val dataCommitSuccess = commitManager.finalPartitionCommit(
+      applicationId,
+      shuffleId,
+      reducerFileGroupsMap.get(shuffleId),
+      partitionId)
+    reply(dataCommitSuccess)
+  }
+
   private def handleUnregisterShuffle(
       appId: String,
       shuffleId: Int): Unit = {
-    // if StageEnd has not been handled, trigger StageEnd
-    if (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
-      logInfo(s"Call StageEnd before Unregister Shuffle $shuffleId.")
-      handleStageEnd(appId, shuffleId)
-      var timeout = stageEndTimeout
-      val delta = 100
-      while (!commitManager.stageEndShuffleSet.contains(shuffleId) && timeout 
> 0) {
-        Thread.sleep(delta)
-        timeout = timeout - delta
-      }
-      if (timeout <= 0) {
-        logError(s"StageEnd Timeout! $shuffleId.")
-      } else {
-        logInfo("[handleUnregisterShuffle] Wait for handleStageEnd complete 
cost" +
-          s" ${stageEndTimeout - timeout}ms")
+    if (getPartitionType(shuffleId) == PartitionType.REDUCE) {
+      // if StageEnd has not been handled, trigger StageEnd
+      if (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
+        logInfo(s"Call StageEnd before Unregister Shuffle $shuffleId.")
+        handleStageEnd(appId, shuffleId)
+        var timeout = stageEndTimeout
+        val delta = 100
+        while (!commitManager.stageEndShuffleSet.contains(shuffleId) && 
timeout > 0) {
+          Thread.sleep(delta)
+          timeout = timeout - delta
+        }
+        if (timeout <= 0) {
+          logError(s"StageEnd Timeout! $shuffleId.")
+        } else {
+          logInfo("[handleUnregisterShuffle] Wait for handleStageEnd complete 
cost" +
+            s" ${stageEndTimeout - timeout}ms")
+        }
       }
     }
 
diff --git 
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java 
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 2fb60718..95470f6c 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -91,6 +91,16 @@ public class DummyShuffleClient extends ShuffleClient {
   public void mapperEnd(
       String applicationId, int shuffleId, int mapId, int attemptId, int 
numMappers) {}
 
+  @Override
+  public void mapPartitionMapperEnd(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int numMappers,
+      int partitionId)
+      throws IOException {}
+
   @Override
   public void cleanup(String applicationId, int shuffleId, int mapId, int 
attemptId) {}
 
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 2cf93798..ae4f3816 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -232,6 +232,7 @@ message PbMapperEnd {
   int32 mapId = 3;
   int32 attemptId = 4;
   int32 numMappers = 5;
+  int32 partitionId = 6;
 }
 
 message PbMapperEndResponse {
@@ -245,7 +246,10 @@ message PbGetReducerFileGroup {
 
 message PbGetReducerFileGroupResponse {
   int32 status = 1;
-  repeated PbFileGroup fileGroup = 2;
+  // PartitionId -> Partition FileGroup
+  map<int32, PbFileGroup> fileGroups = 2;
+
+  // only reduce partition mode needs to know valid attempts
   repeated int32 attempts = 3;
 }
 
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
 
b/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
index 4f29c63a..78550907 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
@@ -46,6 +46,24 @@ class PartitionLocationInfo extends Logging {
     slavePartitionLocations.containsKey(shuffleKey)
   }
 
+  def containsRelatedShuffleOrPartition(shuffleKey: String, partitionIdOpt: 
Option[Int]): Boolean =
+    this.synchronized {
+      partitionIdOpt match {
+        case Some(partitionId) =>
+          containsPartition(shuffleKey, partitionId)
+        case None => containsShuffle(shuffleKey)
+      }
+    }
+
+  private def containsPartition(shuffleKey: String, partitionId: Int): Boolean 
= this
+    .synchronized {
+      val contain = masterPartitionLocations.containsKey(
+        shuffleKey) && 
masterPartitionLocations.get(shuffleKey).containsKey(partitionId)
+      contain || (slavePartitionLocations.containsKey(shuffleKey) && 
slavePartitionLocations.get(
+        shuffleKey)
+        .containsKey(partitionId))
+    }
+
   def addMasterPartition(shuffleKey: String, location: PartitionLocation): Int 
= {
     addPartition(shuffleKey, location, masterPartitionLocations)
   }
@@ -75,26 +93,44 @@ class PartitionLocationInfo extends Logging {
   }
 
   def getAllMasterLocations(shuffleKey: String): util.List[PartitionLocation] 
= this.synchronized {
-    if (masterPartitionLocations.containsKey(shuffleKey)) {
-      masterPartitionLocations.get(shuffleKey)
-        .values()
-        .asScala
-        .flatMap(_.asScala)
-        .toList
-        .asJava
-    } else {
-      new util.ArrayList[PartitionLocation]()
-    }
+    getMasterLocations(shuffleKey)
   }
 
   def getAllSlaveLocations(shuffleKey: String): util.List[PartitionLocation] = 
this.synchronized {
-    if (slavePartitionLocations.containsKey(shuffleKey)) {
-      slavePartitionLocations.get(shuffleKey)
-        .values()
-        .asScala
-        .flatMap(_.asScala)
-        .toList
-        .asJava
+    getSlaveLocations(shuffleKey)
+  }
+
+  def getMasterLocations(
+      shuffleKey: String,
+      partitionIdOpt: Option[Int] = None): util.List[PartitionLocation] = {
+    getLocations(shuffleKey, masterPartitionLocations, partitionIdOpt)
+  }
+
+  def getSlaveLocations(
+      shuffleKey: String,
+      partitionIdOpt: Option[Int] = None): util.List[PartitionLocation] = {
+    getLocations(shuffleKey, slavePartitionLocations, partitionIdOpt)
+  }
+
+  private def getLocations(
+      shuffleKey: String,
+      partitionInfo: PartitionInfo,
+      partitionIdOpt: Option[Int] = None): util.List[PartitionLocation] = 
this.synchronized {
+    if (partitionInfo.containsKey(shuffleKey)) {
+      partitionIdOpt match {
+        case Some(partitionId) => partitionInfo.get(shuffleKey)
+            .values()
+            .asScala
+            .flatMap(_.asScala)
+            .filter(_.getId == partitionId)
+            .toList.asJava
+        case None =>
+          partitionInfo.get(shuffleKey)
+            .values()
+            .asScala
+            .flatMap(_.asScala)
+            .toList.asJava
+      }
     } else {
       new util.ArrayList[PartitionLocation]()
     }
@@ -201,6 +237,24 @@ class PartitionLocationInfo extends Logging {
     }
   }
 
+  def removeAllRelatedPartitions(
+      shuffleKey: String,
+      partitionIdOpt: Option[Int]): Unit = this
+    .synchronized {
+      partitionIdOpt match {
+        case Some(partitionId) =>
+          if (masterPartitionLocations.containsKey(shuffleKey)) {
+            masterPartitionLocations.get(shuffleKey).remove(partitionId)
+          }
+          if (slavePartitionLocations.containsKey(shuffleKey)) {
+            slavePartitionLocations.get(shuffleKey).remove(partitionId)
+          }
+        case None =>
+          removeMasterPartitions(shuffleKey)
+          removeSlavePartitions(shuffleKey)
+      }
+    }
+
   /**
    * @param shuffleKey
    * @param uniqueIds
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 9697aabb..d8255005 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -242,7 +242,8 @@ object ControlMessages extends Logging {
       shuffleId: Int,
       mapId: Int,
       attemptId: Int,
-      numMappers: Int)
+      numMappers: Int,
+      partitionId: Int)
     extends MasterMessage
 
   case class MapperEndResponse(status: StatusCode) extends MasterMessage
@@ -253,7 +254,7 @@ object ControlMessages extends Logging {
   // Path can't be serialized
   case class GetReducerFileGroupResponse(
       status: StatusCode,
-      fileGroup: Array[Array[PartitionLocation]],
+      fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
       attempts: Array[Int])
     extends MasterMessage
 
@@ -521,13 +522,14 @@ object ControlMessages extends Logging {
     case pb: PbChangeLocationResponse =>
       new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, 
pb.toByteArray)
 
-    case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers) =>
+    case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers, 
partitionId) =>
       val payload = PbMapperEnd.newBuilder()
         .setApplicationId(applicationId)
         .setShuffleId(shuffleId)
         .setMapId(mapId)
         .setAttemptId(attemptId)
         .setNumMappers(numMappers)
+        .setPartitionId(partitionId)
         .build().toByteArray
       new TransportMessage(MessageType.MAPPER_END, payload)
 
@@ -547,13 +549,13 @@ object ControlMessages extends Logging {
       val builder = PbGetReducerFileGroupResponse
         .newBuilder()
         .setStatus(status.getValue)
-      builder.addAllFileGroup(
-        fileGroup.map { arr =>
-          PbFileGroup.newBuilder().addAllLocations(arr
-            .map(PbSerDeUtils.toPbPartitionLocation).toIterable.asJava).build()
-        }
-          .toIterable
-          .asJava)
+      builder.putAllFileGroups(
+        fileGroup.asScala.map { case (partitionId, fileGroup) =>
+          (
+            partitionId,
+            
PbFileGroup.newBuilder().addAllLocations(fileGroup.asScala.map(PbSerDeUtils
+              .toPbPartitionLocation).toList.asJava).build())
+        }.asJava)
       builder.addAllAttempts(attempts.map(new Integer(_)).toIterable.asJava)
       val payload = builder.build().toByteArray
       new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, 
payload)
@@ -875,7 +877,8 @@ object ControlMessages extends Logging {
           pbMapperEnd.getShuffleId,
           pbMapperEnd.getMapId,
           pbMapperEnd.getAttemptId,
-          pbMapperEnd.getNumMappers)
+          pbMapperEnd.getNumMappers,
+          pbMapperEnd.getPartitionId)
 
       case MAPPER_END_RESPONSE =>
         val pbMapperEndResponse = 
PbMapperEndResponse.parseFrom(message.getPayload)
@@ -890,9 +893,14 @@ object ControlMessages extends Logging {
       case GET_REDUCER_FILE_GROUP_RESPONSE =>
         val pbGetReducerFileGroupResponse = PbGetReducerFileGroupResponse
           .parseFrom(message.getPayload)
-        val fileGroup = 
pbGetReducerFileGroupResponse.getFileGroupList.asScala.map { fg =>
-          
fg.getLocationsList.asScala.map(PbSerDeUtils.fromPbPartitionLocation).toArray
-        }.toArray
+        val fileGroup = 
pbGetReducerFileGroupResponse.getFileGroupsMap.asScala.map {
+          case (partitionId, fileGroup) =>
+            (
+              partitionId,
+              fileGroup.getLocationsList.asScala.map(
+                PbSerDeUtils.fromPbPartitionLocation).toSet.asJava)
+        }.asJava
+
         val attempts = 
pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray
         GetReducerFileGroupResponse(
           Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/util/FunctionConverter.scala 
b/common/src/main/scala/org/apache/celeborn/common/util/FunctionConverter.scala
new file mode 100644
index 00000000..71c75cf2
--- /dev/null
+++ 
b/common/src/main/scala/org/apache/celeborn/common/util/FunctionConverter.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util
+
+/**
+ * Implicit conversion from a Scala (2.11) function to a Java function
+ */
+object FunctionConverter {
+
+  implicit def scalaFunctionToJava[From, To](function: (From) => To)
+      : java.util.function.Function[From, To] = {
+    new java.util.function.Function[From, To] {
+      override def apply(input: From): To = function(input)
+    }
+  }
+
+}
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala 
b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
index d631d7dc..70ecc84b 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
@@ -17,7 +17,12 @@
 
 package org.apache.celeborn.common.util
 
+import java.util
+
 import org.apache.celeborn.RssFunSuite
+import org.apache.celeborn.common.protocol.PartitionLocation
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse,
 MapperEnd}
+import org.apache.celeborn.common.protocol.message.StatusCode
 
 class UtilsSuite extends RssFunSuite {
 
@@ -92,4 +97,45 @@ class UtilsSuite extends RssFunSuite {
   test("getThreadDump") {
     assert(Utils.getThreadDump().nonEmpty)
   }
+
+  test("MapperEnd class convert with pb") {
+    val mapperEnd = MapperEnd("application1", 1, 1, 1, 2, 1)
+    val mapperEndTrans =
+      
Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
+    assert(mapperEnd == mapperEndTrans)
+  }
+
+  test("GetReducerFileGroupResponse class convert with pb") {
+    val fileGroup = new util.HashMap[Integer, util.Set[PartitionLocation]]
+    fileGroup.put(0, partitionLocation(0))
+    fileGroup.put(1, partitionLocation(1))
+    fileGroup.put(2, partitionLocation(2))
+
+    val attempts = Array(0, 0, 1)
+    val response = GetReducerFileGroupResponse(StatusCode.STAGE_ENDED, 
fileGroup, attempts)
+    val responseTrans = 
Utils.fromTransportMessage(Utils.toTransportMessage(response)).asInstanceOf[
+      GetReducerFileGroupResponse]
+
+    assert(response.status == responseTrans.status)
+    assert(response.attempts.deep == responseTrans.attempts.deep)
+    val set =
+      (response.fileGroup.values().toArray diff 
responseTrans.fileGroup.values().toArray).toSet
+    assert(set.size == 0)
+  }
+
+  def partitionLocation(partitionId: Int): util.HashSet[PartitionLocation] = {
+    val partitionSet = new util.HashSet[PartitionLocation]
+    for (i <- 0 until 3) {
+      partitionSet.add(new PartitionLocation(
+        partitionId,
+        i,
+        "host",
+        100,
+        1000,
+        1001,
+        100,
+        PartitionLocation.Mode.MASTER))
+    }
+    partitionSet
+  }
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
index 4e857e17..926210ff 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
@@ -36,6 +36,9 @@ class ShuffleClientSuite extends AnyFunSuite with 
MiniClusterFeature
   val APP = "app-1"
   var shuffleClient: ShuffleClientImpl = _
   var lifecycleManager: LifecycleManager = _
+  val numMappers = 8
+  val mapId = 1
+  val attemptId = 0
 
   override def beforeAll(): Unit = {
     val masterConf = Map(
@@ -54,11 +57,8 @@ class ShuffleClientSuite extends AnyFunSuite with 
MiniClusterFeature
     shuffleClient.setupMetaServiceRef(lifecycleManager.self)
   }
 
-  test(s"test register map partition task with first attemptId") {
+  test(s"test register map partition task") {
     val shuffleId = 1
-    val numMappers = 8
-    val mapId = 1
-    val attemptId = 0
     var location =
       shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, 
mapId, attemptId)
     Assert.assertEquals(location.getId, 
PackedPartitionId.packedPartitionId(mapId, attemptId))
@@ -93,6 +93,39 @@ class ShuffleClientSuite extends AnyFunSuite with 
MiniClusterFeature
     Assert.assertEquals(count, numMappers + 1)
   }
 
+  test(s"test map end & get reducer file group") {
+    val shuffleId = 2
+    shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, 
attemptId)
+    shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 
1, attemptId)
+    shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 
2, attemptId)
+    shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, 
attemptId + 1)
+    shuffleClient.mapPartitionMapperEnd(APP, shuffleId, numMappers, mapId, 
attemptId, mapId)
+    // retry
+    shuffleClient.mapPartitionMapperEnd(APP, shuffleId, numMappers, mapId, 
attemptId, mapId)
+    // another attempt
+    shuffleClient.mapPartitionMapperEnd(
+      APP,
+      shuffleId,
+      numMappers,
+      mapId,
+      attemptId + 1,
+      PackedPartitionId
+        .packedPartitionId(mapId, attemptId + 1))
+    // another mapper
+    shuffleClient.mapPartitionMapperEnd(APP, shuffleId, numMappers, mapId + 1, 
attemptId, mapId + 1)
+
+    // reduce file group size (for empty partitions)
+    Assert.assertEquals(shuffleClient.getReduceFileGroupsMap.size(), 0)
+
+    // reduce normal empty RssInputStream
+    var stream = shuffleClient.readPartition(APP, shuffleId, 1, 1)
+    Assert.assertEquals(stream.read(), -1)
+
+    // reduce normal null partition for RssInputStream
+    stream = shuffleClient.readPartition(APP, shuffleId, 3, 1)
+    Assert.assertEquals(stream.read(), -1)
+  }
+
   override def afterAll(): Unit = {
     // TODO refactor MiniCluster later
     println("test done")
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
index 932998e7..97c333a3 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
@@ -50,8 +50,16 @@ class HugeDataTest extends AnyFunSuite
   test("celeborn spark integration test - huge data") {
     val sparkConf = new 
SparkConf().setAppName("rss-demo").setMaster("local[4]")
     val ss = SparkSession.builder().config(updateSparkConf(sparkConf, 
false)).getOrCreate()
-    ss.sparkContext.parallelize(1 to 10000, 2)
-      .map { i => (i, Range(1, 10000).mkString(",")) }.groupByKey(16).collect()
+    val value = Range(1, 10000).mkString(",")
+    val tuples = ss.sparkContext.parallelize(1 to 10000, 2)
+      .map { i => (i, value) }.groupByKey(16).collect()
+
+    // verify result
+    assert(tuples.length == 10000)
+    for (elem <- tuples) {
+      assert(elem._2.mkString(",").equals(value))
+    }
+
     ss.stop()
   }
 }

Reply via email to