This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new e2196e93 [CELEBORN-56] [ISSUE-945] handle map partition mapper end
(#1003)
e2196e93 is described below
commit e2196e93830fdd90bdb3861d513942cb48723f0e
Author: Shuang <[email protected]>
AuthorDate: Wed Dec 7 21:09:02 2022 +0800
[CELEBORN-56] [ISSUE-945] handle map partition mapper end (#1003)
---
.../org/apache/celeborn/client/ShuffleClient.java | 12 +-
.../apache/celeborn/client/ShuffleClientImpl.java | 46 ++-
.../org/apache/celeborn/client/CommitManager.scala | 348 +++++++++++++++------
.../apache/celeborn/client/LifecycleManager.scala | 140 ++++++---
.../apache/celeborn/client/DummyShuffleClient.java | 10 +
common/src/main/proto/TransportMessages.proto | 6 +-
.../common/meta/PartitionLocationInfo.scala | 88 +++++-
.../common/protocol/message/ControlMessages.scala | 36 ++-
.../celeborn/common/util/FunctionConverter.scala | 32 ++
.../apache/celeborn/common/util/UtilsSuite.scala | 46 +++
.../celeborn/tests/client/ShuffleClientSuite.scala | 41 ++-
.../apache/celeborn/tests/spark/HugeDataTest.scala | 12 +-
12 files changed, 629 insertions(+), 188 deletions(-)
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index b9fcb9e1..3b76333f 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -149,11 +149,21 @@ public abstract class ShuffleClient implements Cloneable {
public abstract void pushMergedData(String applicationId, int shuffleId, int
mapId, int attemptId)
throws IOException;
- // Report partition locations written by the completed map task
+ // Report partition locations written by the completed map task of
ReducePartition Shuffle Type
public abstract void mapperEnd(
String applicationId, int shuffleId, int mapId, int attemptId, int
numMappers)
throws IOException;
+ // Report partition locations written by the completed map task of
MapPartition Shuffle Type
+ public abstract void mapPartitionMapperEnd(
+ String applicationId,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int partitionId)
+ throws IOException;
+
// Cleanup states of the map task
public abstract void cleanup(String applicationId, int shuffleId, int mapId,
int attemptId);
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 67a6ee4e..a3d0c6b1 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -113,10 +113,10 @@ public class ShuffleClientImpl extends ShuffleClient {
};
private static class ReduceFileGroups {
- final PartitionLocation[][] partitionGroups;
+ final Map<Integer, Set<PartitionLocation>> partitionGroups;
final int[] mapAttempts;
- ReduceFileGroups(PartitionLocation[][] partitionGroups, int[] mapAttempts)
{
+ ReduceFileGroups(Map<Integer, Set<PartitionLocation>> partitionGroups,
int[] mapAttempts) {
this.partitionGroups = partitionGroups;
this.mapAttempts = mapAttempts;
}
@@ -1028,6 +1028,29 @@ public class ShuffleClientImpl extends ShuffleClient {
public void mapperEnd(
String applicationId, int shuffleId, int mapId, int attemptId, int
numMappers)
throws IOException {
+ mapEndInternal(applicationId, shuffleId, mapId, attemptId, numMappers, -1);
+ }
+
+ @Override
+ public void mapPartitionMapperEnd(
+ String applicationId,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int partitionId)
+ throws IOException {
+ mapEndInternal(applicationId, shuffleId, mapId, attemptId, numMappers,
partitionId);
+ }
+
+ private void mapEndInternal(
+ String applicationId,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ Integer partitionId)
+ throws IOException {
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new
PushState(conf));
@@ -1036,7 +1059,7 @@ public class ShuffleClientImpl extends ShuffleClient {
MapperEndResponse response =
driverRssMetaService.askSync(
- new MapperEnd(applicationId, shuffleId, mapId, attemptId,
numMappers),
+ new MapperEnd(applicationId, shuffleId, mapId, attemptId,
numMappers, partitionId),
ClassTag$.MODULE$.apply(MapperEndResponse.class));
if (response.status() != StatusCode.SUCCESS) {
throw new IOException("MapperEnd failed! StatusCode: " +
response.status());
@@ -1119,9 +1142,10 @@ public class ShuffleClientImpl extends ShuffleClient {
if (response.status() == StatusCode.SUCCESS) {
logger.info(
- "Shuffle {} request reducer file group success using
time:{} ms",
+ "Shuffle {} request reducer file group success using
time:{} ms, result partition ids: {}",
shuffleId,
- (System.nanoTime() - getReducerFileGroupStartTime) /
1000_000);
+ (System.nanoTime() - getReducerFileGroupStartTime) /
1000_000,
+ response.fileGroup().keySet());
return new ReduceFileGroups(response.fileGroup(),
response.attempts());
} else if (response.status() == StatusCode.STAGE_END_TIME_OUT)
{
logger.warn(
@@ -1147,15 +1171,16 @@ public class ShuffleClientImpl extends ShuffleClient {
String msg = "Shuffle data lost for shuffle " + shuffleId + " reduce " +
partitionId + "!";
logger.error(msg);
throw new IOException(msg);
- } else if (fileGroups.partitionGroups.length == 0) {
- logger.warn("Shuffle data is empty for shuffle {} reduce {}.",
shuffleId, partitionId);
+ } else if (fileGroups.partitionGroups.size() == 0
+ || !fileGroups.partitionGroups.containsKey(partitionId)) {
+ logger.warn("Shuffle data is empty for shuffle {} partitionId {}.",
shuffleId, partitionId);
return RssInputStream.empty();
} else {
return RssInputStream.create(
conf,
dataClientFactory,
shuffleKey,
- fileGroups.partitionGroups[partitionId],
+ fileGroups.partitionGroups.get(partitionId).toArray(new
PartitionLocation[0]),
fileGroups.mapAttempts,
attemptNumber,
startMapIndex,
@@ -1163,6 +1188,11 @@ public class ShuffleClientImpl extends ShuffleClient {
}
}
+ @VisibleForTesting
+ public Map<Integer, ReduceFileGroups> getReduceFileGroupsMap() {
+ return reduceFileGroupsMap;
+ }
+
@Override
public void shutDown() {
if (null != rpcEnv) {
diff --git
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 9bd7aeb7..4bb2ffd9 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -29,11 +29,12 @@ import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
-import org.apache.celeborn.common.protocol.{PartitionLocation, StorageInfo}
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType,
StorageInfo}
import
org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles,
CommitFilesResponse}
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc.RpcEndpointRef
import org.apache.celeborn.common.util.{ThreadUtils, Utils}
+import org.apache.celeborn.common.util.FunctionConverter._
case class CommitPartitionRequest(
applicationId: String,
@@ -41,8 +42,8 @@ case class CommitPartitionRequest(
partition: PartitionLocation)
case class ShuffleCommittedInfo(
- committedMasterIds: util.List[String],
- committedSlaveIds: util.List[String],
+ committedMasterIds: ConcurrentHashMap[Int, util.List[String]],
+ committedSlaveIds: ConcurrentHashMap[Int, util.List[String]],
failedMasterPartitionIds: ConcurrentHashMap[String, WorkerInfo],
failedSlavePartitionIds: ConcurrentHashMap[String, WorkerInfo],
committedMasterStorageInfos: ConcurrentHashMap[String, StorageInfo],
@@ -51,7 +52,8 @@ case class ShuffleCommittedInfo(
currentShuffleFileCount: LongAdder,
commitPartitionRequests: util.Set[CommitPartitionRequest],
handledCommitPartitionRequests: util.Set[PartitionLocation],
- inFlightCommitRequest: AtomicInteger)
+ allInFlightCommitRequestNum: AtomicInteger,
+ partitionInFlightCommitRequestNum: ConcurrentHashMap[Int, AtomicInteger])
class CommitManager(appId: String, val conf: CelebornConf, lifecycleManager:
LifecycleManager)
extends Logging {
@@ -60,7 +62,8 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
-
+ // shuffleId -> in processing partitionId set
+ private val inProcessMapPartitionEndIds = new ConcurrentHashMap[Int,
util.Set[Int]]()
private val pushReplicateEnabled = conf.pushReplicateEnabled
private val batchHandleCommitPartitionEnabled =
conf.batchHandleCommitPartitionEnabled
@@ -92,6 +95,67 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
committedPartitionInfo.asScala.foreach { case (shuffleId,
shuffleCommittedInfo) =>
batchHandleCommitPartitionExecutors.submit {
new Runnable {
+ val partitionType =
lifecycleManager.getPartitionType(shuffleId)
+ def isPartitionInProcess(partitionId: Int): Boolean = {
+ if (inProcessMapPartitionEndIds.containsKey(shuffleId) &&
+
inProcessMapPartitionEndIds.get(shuffleId).contains(partitionId)) {
+ true
+ } else {
+ false
+ }
+ }
+
+ def incrementInflightNum(workerToRequests: Map[
+ WorkerInfo,
+ collection.Set[PartitionLocation]]): Unit = {
+ if (partitionType == PartitionType.MAP) {
+ workerToRequests.foreach {
+ case (_, partitions) =>
+ partitions.groupBy(_.getId).foreach { case (id, _) =>
+ val atomicInteger = shuffleCommittedInfo
+ .partitionInFlightCommitRequestNum
+ .computeIfAbsent(id, (k: Int) => new
AtomicInteger(0))
+ atomicInteger.incrementAndGet()
+ }
+ }
+ }
+ shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(
+ workerToRequests.size)
+ }
+
+ def decrementInflightNum(
+ workerToRequests: Map[WorkerInfo,
collection.Set[PartitionLocation]])
+ : Unit = {
+ if (partitionType == PartitionType.MAP) {
+ workerToRequests.foreach {
+ case (_, partitions) =>
+ partitions.groupBy(_.getId).foreach { case (id, _) =>
+
shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(
+ id).decrementAndGet()
+ }
+ }
+ }
+ shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(
+ -workerToRequests.size)
+ }
+
+ def getUnCommitPartitionRequests(
+ commitPartitionRequests:
util.Set[CommitPartitionRequest])
+ : scala.collection.mutable.Set[CommitPartitionRequest] =
{
+ if (partitionType == PartitionType.MAP) {
+ commitPartitionRequests.asScala.filterNot { request =>
+ shuffleCommittedInfo.handledCommitPartitionRequests
+ .contains(request.partition) && isPartitionInProcess(
+ request.partition.getId)
+ }
+ } else {
+ commitPartitionRequests.asScala.filterNot { request =>
+ shuffleCommittedInfo.handledCommitPartitionRequests
+ .contains(request.partition)
+ }
+ }
+ }
+
override def run(): Unit = {
val workerToRequests = shuffleCommittedInfo.synchronized {
// When running to here, if handleStageEnd got lock
first and commitFiles,
@@ -105,12 +169,8 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
shuffleCommittedInfo.commitPartitionRequests.clear()
Map.empty[WorkerInfo, Set[PartitionLocation]]
} else {
- val batch = new util.HashSet[CommitPartitionRequest]()
-
batch.addAll(shuffleCommittedInfo.commitPartitionRequests)
- val currentBatch = batch.asScala.filterNot { request =>
- shuffleCommittedInfo.handledCommitPartitionRequests
- .contains(request.partition)
- }
+ val currentBatch =
+
getUnCommitPartitionRequests(shuffleCommittedInfo.commitPartitionRequests)
shuffleCommittedInfo.commitPartitionRequests.clear()
currentBatch.foreach { commitPartitionRequest =>
shuffleCommittedInfo.handledCommitPartitionRequests
@@ -131,8 +191,7 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
Seq(request.partition)
}
}.groupBy(_.getWorker)
- shuffleCommittedInfo.inFlightCommitRequest.addAndGet(
- workerToRequests.size)
+ incrementInflightNum(workerToRequests)
workerToRequests
} else {
Map.empty[WorkerInfo, Set[PartitionLocation]]
@@ -180,7 +239,7 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
}
lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
} finally {
-
shuffleCommittedInfo.inFlightCommitRequest.addAndGet(-workerToRequests.size)
+ decrementInflightNum(workerToRequests)
}
}
}
@@ -204,8 +263,8 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
committedPartitionInfo.put(
shuffleId,
ShuffleCommittedInfo(
- new util.ArrayList[String](),
- new util.ArrayList[String](),
+ new ConcurrentHashMap[Int, util.List[String]](),
+ new ConcurrentHashMap[Int, util.List[String]](),
new ConcurrentHashMap[String, WorkerInfo](),
new ConcurrentHashMap[String, WorkerInfo](),
new ConcurrentHashMap[String, StorageInfo](),
@@ -214,11 +273,14 @@ class CommitManager(appId: String, val conf:
CelebornConf, lifecycleManager: Lif
new LongAdder,
new util.HashSet[CommitPartitionRequest](),
new util.HashSet[PartitionLocation](),
- new AtomicInteger()))
+ new AtomicInteger(),
+ new ConcurrentHashMap[Int, AtomicInteger]()))
}
def removeExpiredShuffle(shuffleId: Int): Unit = {
committedPartitionInfo.remove(shuffleId)
+ inProcessStageEndShuffleSet.remove(shuffleId)
+ inProcessMapPartitionEndIds.remove(shuffleId)
dataLostShuffleSet.remove(shuffleId)
stageEndShuffleSet.remove(shuffleId)
}
@@ -304,8 +366,21 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
shuffleCommittedInfo.synchronized {
// record committed partitionIds
- shuffleCommittedInfo.committedMasterIds.addAll(res.committedMasterIds)
- shuffleCommittedInfo.committedSlaveIds.addAll(res.committedSlaveIds)
+ res.committedMasterIds.asScala.foreach({
+ case commitMasterId =>
+ val partitionUniqueIdList =
shuffleCommittedInfo.committedMasterIds.computeIfAbsent(
+ Utils.splitPartitionLocationUniqueId(commitMasterId)._1,
+ (k: Int) => new util.ArrayList[String]())
+ partitionUniqueIdList.add(commitMasterId)
+ })
+
+ res.committedSlaveIds.asScala.foreach({
+ case commitSlaveId =>
+ val partitionUniqueIdList =
shuffleCommittedInfo.committedSlaveIds.computeIfAbsent(
+ Utils.splitPartitionLocationUniqueId(commitSlaveId)._1,
+ (k: Int) => new util.ArrayList[String]())
+ partitionUniqueIdList.add(commitSlaveId)
+ })
// record committed partitions storage hint and disk hint
shuffleCommittedInfo.committedMasterStorageInfos.putAll(res.committedMasterStorageInfos)
@@ -359,7 +434,7 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
def finalCommit(
applicationId: String,
shuffleId: Int,
- fileGroups: Array[Array[PartitionLocation]]): Unit = {
+ fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]]):
Unit = {
if (stageEndShuffleSet.contains(shuffleId)) {
logInfo(s"[handleStageEnd] Shuffle $shuffleId already ended!")
return
@@ -372,10 +447,31 @@ class CommitManager(appId: String, val conf:
CelebornConf, lifecycleManager: Lif
inProcessStageEndShuffleSet.add(shuffleId)
}
// ask allLocations workers holding partitions to commit files
+ val allocatedWorkers =
lifecycleManager.shuffleAllocatedWorkers.get(shuffleId)
+ val dataLost = handleCommitFiles(applicationId, shuffleId,
allocatedWorkers, None, fileGroups)
+
+ // reply
+ if (!dataLost) {
+ logInfo(s"Succeed to handle stageEnd for $shuffleId.")
+ // record in stageEndShuffleSet
+ stageEndShuffleSet.add(shuffleId)
+ } else {
+ logError(s"Failed to handle stageEnd for $shuffleId, lost file!")
+ dataLostShuffleSet.add(shuffleId)
+ // record in stageEndShuffleSet
+ stageEndShuffleSet.add(shuffleId)
+ }
+ inProcessStageEndShuffleSet.remove(shuffleId)
+ }
+
+ private def handleCommitFiles(
+ applicationId: String,
+ shuffleId: Int,
+ allocatedWorkers: util.Map[WorkerInfo, PartitionLocationInfo],
+ partitionIdOpt: Option[Int] = None,
+ fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]]):
Boolean = {
val masterPartMap = new ConcurrentHashMap[String, PartitionLocation]
val slavePartMap = new ConcurrentHashMap[String, PartitionLocation]
-
- val allocatedWorkers =
lifecycleManager.shuffleAllocatedWorkers.get(shuffleId)
val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
val commitFilesFailedWorkers = new ConcurrentHashMap[WorkerInfo,
(StatusCode, Long)]()
val commitFileStartTime = System.nanoTime()
@@ -386,8 +482,9 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
"CommitFiles",
parallelism) { case (worker, partitionLocationInfo) =>
if (partitionLocationInfo.containsShuffle(shuffleId.toString)) {
- val masterParts =
partitionLocationInfo.getAllMasterLocations(shuffleId.toString)
- val slaveParts =
partitionLocationInfo.getAllSlaveLocations(shuffleId.toString)
+ val masterParts =
+ partitionLocationInfo.getMasterLocations(shuffleId.toString,
partitionIdOpt)
+ val slaveParts =
partitionLocationInfo.getSlaveLocations(shuffleId.toString, partitionIdOpt)
masterParts.asScala.foreach { p =>
val partition = new PartitionLocation(p)
partition.setFetchPort(worker.fetchPort)
@@ -421,59 +518,16 @@ class CommitManager(appId: String, val conf:
CelebornConf, lifecycleManager: Lif
commitFilesFailedWorkers)
}
}
+ lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
+ // check all inflight request complete, for map partition, it's for single
partitionId
+ waitInflightRequestComplete(shuffleId, shuffleCommittedInfo,
partitionIdOpt)
- def hasCommitFailedIds: Boolean = {
- val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
- if (!pushReplicateEnabled &&
shuffleCommittedInfo.failedMasterPartitionIds.size() != 0) {
- val msg =
- shuffleCommittedInfo.failedMasterPartitionIds.asScala.map {
- case (partitionId, workerInfo) =>
- s"Lost partition $partitionId in worker
[${workerInfo.readableAddress()}]"
- }.mkString("\n")
- logError(
- s"""
- |For shuffle $shuffleKey partition data lost:
- |$msg
- |""".stripMargin)
- true
- } else {
- val failedBothPartitionIdsToWorker =
- shuffleCommittedInfo.failedMasterPartitionIds.asScala.flatMap {
- case (partitionId, worker) =>
- if
(shuffleCommittedInfo.failedSlavePartitionIds.contains(partitionId)) {
- Some(partitionId -> (worker,
shuffleCommittedInfo.failedSlavePartitionIds.get(
- partitionId)))
- } else {
- None
- }
- }
- if (failedBothPartitionIdsToWorker.nonEmpty) {
- val msg = failedBothPartitionIdsToWorker.map {
- case (partitionId, (masterWorker, slaveWorker)) =>
- s"Lost partition $partitionId " +
- s"in master worker [${masterWorker.readableAddress()}] and
slave worker [$slaveWorker]"
- }.mkString("\n")
- logError(
- s"""
- |For shuffle $shuffleKey partition data lost:
- |$msg
- |""".stripMargin)
- true
- } else {
- false
- }
- }
- }
-
- while (shuffleCommittedInfo.inFlightCommitRequest.get() > 0) {
- Thread.sleep(1000)
- }
-
- val dataLost = hasCommitFailedIds
+ // check all data lost or not, for map partition, it's for single
partitionId
+ val dataLost = checkDataLost(applicationId, shuffleId, partitionIdOpt)
if (!dataLost) {
val committedPartitions = new util.HashMap[String, PartitionLocation]
- shuffleCommittedInfo.committedMasterIds.asScala.foreach { id =>
+ getPartitionUniqueIds(shuffleCommittedInfo.committedMasterIds,
partitionIdOpt).foreach { id =>
if (shuffleCommittedInfo.committedMasterStorageInfos.get(id) == null) {
logDebug(s"$applicationId-$shuffleId $id storage hint was not
returned")
} else {
@@ -484,7 +538,7 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
}
}
- shuffleCommittedInfo.committedSlaveIds.asScala.foreach { id =>
+ getPartitionUniqueIds(shuffleCommittedInfo.committedSlaveIds,
partitionIdOpt).foreach { id =>
val slavePartition = slavePartMap.get(id)
if (shuffleCommittedInfo.committedSlaveStorageInfos.get(id) == null) {
logDebug(s"$applicationId-$shuffleId $id storage hint was not
returned")
@@ -503,14 +557,11 @@ class CommitManager(appId: String, val conf:
CelebornConf, lifecycleManager: Lif
}
}
- val sets = Array.fill(fileGroups.length)(new
util.HashSet[PartitionLocation]())
committedPartitions.values().asScala.foreach { partition =>
- sets(partition.getId).add(partition)
- }
- var i = 0
- while (i < fileGroups.length) {
- fileGroups(i) = sets(i).toArray(new Array[PartitionLocation](0))
- i += 1
+ val partitionLocations = fileGroups.computeIfAbsent(
+ partition.getId,
+ (k: Integer) => new util.HashSet[PartitionLocation]())
+ partitionLocations.add(partition)
}
logInfo(s"Shuffle $shuffleId " +
@@ -518,25 +569,130 @@ class CommitManager(appId: String, val conf:
CelebornConf, lifecycleManager: Lif
s"using ${(System.nanoTime() - commitFileStartTime) / 1000000} ms")
}
- // reply
- if (!dataLost) {
- logInfo(s"Succeed to handle stageEnd for $shuffleId.")
- // record in stageEndShuffleSet
- stageEndShuffleSet.add(shuffleId)
+ dataLost
+ }
+
+ private def getPartitionIds(
+ partitionIds: ConcurrentHashMap[String, WorkerInfo],
+ partitionIdOpt: Option[Int]): util.Map[String, WorkerInfo] = {
+ partitionIdOpt match {
+ case Some(partitionId) => partitionIds.asScala.filter(p =>
+ Utils.splitPartitionLocationUniqueId(p._1)._1 ==
+ partitionId).asJava
+ case None => partitionIds
+ }
+ }
+
+ private def waitInflightRequestComplete(
+ shuffleId: Int,
+ shuffleCommittedInfo: ShuffleCommittedInfo,
+ partitionIdOpt: Option[Int]): Unit = {
+ lifecycleManager.getPartitionType(shuffleId) match {
+ case PartitionType.REDUCE =>
+ while (shuffleCommittedInfo.allInFlightCommitRequestNum.get() > 0) {
+ Thread.sleep(1000)
+ }
+ case PartitionType.MAP => partitionIdOpt match {
+ case Some(partitionId) =>
+ if
(shuffleCommittedInfo.partitionInFlightCommitRequestNum.containsKey(partitionId))
{
+ while
(shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(
+ partitionId).get() > 0) {
+ Thread.sleep(1000)
+ }
+ }
+ }
+ }
+ }
+
+ private def checkDataLost(
+ applicationId: String,
+ shuffleId: Int,
+ partitionIdOpt: Option[Int]): Boolean = {
+ val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
+ val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
+ val masterPartitionUniqueIdMap =
+ getPartitionIds(shuffleCommittedInfo.failedMasterPartitionIds,
partitionIdOpt)
+ if (!pushReplicateEnabled && masterPartitionUniqueIdMap.size() != 0) {
+ val msg =
+ masterPartitionUniqueIdMap.asScala.map {
+ case (partitionUniqueId, workerInfo) =>
+ s"Lost partition $partitionUniqueId in worker
[${workerInfo.readableAddress()}]"
+ }.mkString("\n")
+ logError(
+ s"""
+ |For shuffle $shuffleKey partition data lost:
+ |$msg
+ |""".stripMargin)
+ true
} else {
- logError(s"Failed to handle stageEnd for $shuffleId, lost file!")
- dataLostShuffleSet.add(shuffleId)
- // record in stageEndShuffleSet
- stageEndShuffleSet.add(shuffleId)
+ val slavePartitionUniqueIdMap =
+ getPartitionIds(shuffleCommittedInfo.failedSlavePartitionIds,
partitionIdOpt)
+ val failedBothPartitionIdsToWorker =
masterPartitionUniqueIdMap.asScala.flatMap {
+ case (partitionUniqueId, worker) =>
+ if (slavePartitionUniqueIdMap.asScala.contains(partitionUniqueId)) {
+ Some(partitionUniqueId -> (worker,
slavePartitionUniqueIdMap.get(partitionUniqueId)))
+ } else {
+ None
+ }
+ }
+ if (failedBothPartitionIdsToWorker.nonEmpty) {
+ val msg = failedBothPartitionIdsToWorker.map {
+ case (partitionUniqueId, (masterWorker, slaveWorker)) =>
+ s"Lost partition $partitionUniqueId " +
+ s"in master worker [${masterWorker.readableAddress()}] and slave
worker [$slaveWorker]"
+ }.mkString("\n")
+ logError(
+ s"""
+ |For shuffle $shuffleKey partition data lost:
+ |$msg
+ |""".stripMargin)
+ true
+ } else {
+ false
+ }
}
- inProcessStageEndShuffleSet.remove(shuffleId)
- lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
}
- def removeExpiredShuffle(shuffleId: String): Unit = {
- stageEndShuffleSet.remove(shuffleId)
- dataLostShuffleSet.remove(shuffleId)
- committedPartitionInfo.remove(shuffleId)
+ private def getPartitionUniqueIds(
+ ids: ConcurrentHashMap[Int, util.List[String]],
+ partitionIdOpt: Option[Int]): Iterable[String] = {
+ partitionIdOpt match {
+ case Some(partitionId) => ids.asScala.filter(_._1 ==
partitionId).flatMap(_._2.asScala)
+ case None => ids.asScala.flatMap(_._2.asScala)
+ }
+ }
+
+ def finalPartitionCommit(
+ applicationId: String,
+ shuffleId: Int,
+ fileGroups: ConcurrentHashMap[Integer, util.Set[PartitionLocation]],
+ partitionId: Int): Boolean = {
+ val inProcessingPartitionIds =
+ inProcessMapPartitionEndIds.computeIfAbsent(shuffleId, (k: Int) => new
util.HashSet[Int]())
+ inProcessingPartitionIds.add(partitionId)
+
+ val allocatedWorkers =
+ lifecycleManager.shuffleAllocatedWorkers.get(shuffleId).asScala.filter(p
=>
+ p._2.containsRelatedShuffleOrPartition(shuffleId.toString,
Option(partitionId))).asJava
+
+ var dataCommitSuccess = true
+ if (!allocatedWorkers.isEmpty) {
+ dataCommitSuccess =
+ !handleCommitFiles(
+ applicationId,
+ shuffleId,
+ allocatedWorkers,
+ Option(partitionId),
+ fileGroups)
+ }
+
+ // release resources and clear worker info
+ allocatedWorkers.asScala.foreach { case (_, partitionLocationInfo) =>
+ partitionLocationInfo.removeAllRelatedPartitions(shuffleId.toString,
Option(partitionId))
+ }
+ inProcessingPartitionIds.remove(partitionId)
+
+ dataCommitSuccess
}
def commitMetrics(): (Long, Long) = (totalWritten.sumThenReset(),
fileCount.sumThenReset())
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 0ecce2ce..287ecd30 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -20,8 +20,7 @@ package org.apache.celeborn.client
import java.nio.ByteBuffer
import java.util
import java.util.{function, List => JList}
-import java.util.concurrent.{Callable, ConcurrentHashMap,
ScheduledExecutorService, ScheduledFuture, TimeUnit}
-import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder}
+import java.util.concurrent.{Callable, ConcurrentHashMap, ScheduledFuture,
TimeUnit}
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -29,7 +28,6 @@ import scala.util.Random
import com.google.common.annotations.VisibleForTesting
import com.google.common.cache.{Cache, CacheBuilder}
-import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.haclient.RssHARetryClient
@@ -43,6 +41,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc._
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext,
RemoteNettyRpcCallContext}
import org.apache.celeborn.common.util.{PbSerDeUtils, ThreadUtils, Utils}
+import org.apache.celeborn.common.util.FunctionConverter._
class LifecycleManager(appId: String, val conf: CelebornConf) extends
RpcEndpoint with Logging {
@@ -66,7 +65,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf)
extends RpcEndpoin
val registeredShuffle = ConcurrentHashMap.newKeySet[Int]()
val shuffleMapperAttempts = new ConcurrentHashMap[Int, Array[Int]]()
private val reducerFileGroupsMap =
- new ConcurrentHashMap[Int, Array[Array[PartitionLocation]]]()
+ new ConcurrentHashMap[Int, ConcurrentHashMap[Integer,
util.Set[PartitionLocation]]]()
private val shuffleTaskInfo = new ShuffleTaskInfo()
// maintain each shuffle's map relation of WorkerInfo and partition location
val shuffleAllocatedWorkers =
@@ -310,10 +309,16 @@ class LifecycleManager(appId: String, val conf:
CelebornConf) extends RpcEndpoin
epoch,
oldPartition)
- case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers) =>
- logTrace(s"Received MapperEnd request, " +
- s"${Utils.makeMapKey(applicationId, shuffleId, mapId, attemptId)}.")
- handleMapperEnd(context, applicationId, shuffleId, mapId, attemptId,
numMappers)
+ case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers,
partitionId) =>
+ logTrace(s"Received MapperEnd TaskEnd request, " +
+ s"${Utils.makeMapKey(applicationId, shuffleId, mapId, attemptId)}")
+ val partitionType = getPartitionType(shuffleId)
+ partitionType match {
+ case PartitionType.REDUCE =>
+ handleMapperEnd(context, applicationId, shuffleId, mapId, attemptId,
numMappers)
+ case PartitionType.MAP =>
+ handleMapPartitionEnd(context, applicationId, shuffleId, mapId,
attemptId, partitionId)
+ }
case GetReducerFileGroup(applicationId: String, shuffleId: Int) =>
logDebug(s"Received GetShuffleFileGroup request," +
@@ -483,8 +488,7 @@ class LifecycleManager(appId: String, val conf:
CelebornConf) extends RpcEndpoin
}
}
}
-
- reducerFileGroupsMap.put(shuffleId, new
Array[Array[PartitionLocation]](numReducers))
+ reducerFileGroupsMap.put(shuffleId, new ConcurrentHashMap())
// Fifth, reply the allocated partition location to ShuffleClient.
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
@@ -551,13 +555,15 @@ class LifecycleManager(appId: String, val conf:
CelebornConf) extends RpcEndpoin
return
}
- // If shuffle registered and corresponding map finished, reply MapEnd and
return.
- if (shuffleMapperAttempts.containsKey(shuffleId)
- && shuffleMapperAttempts.get(shuffleId)(mapId) != -1) {
- logWarning(s"[handleRevive] Mapper ended, mapId $mapId, current
attemptId $attemptId, " +
- s"ended attemptId ${shuffleMapperAttempts.get(shuffleId)(mapId)},
shuffleId $shuffleId.")
- context.reply(ChangeLocationResponse(StatusCode.MAP_ENDED, None))
- return
+ // If shuffle registered and corresponding map finished, reply MapEnd and
return. Only for reduce partition type
+ if (getPartitionType(shuffleId) == PartitionType.REDUCE) {
+ if (shuffleMapperAttempts.containsKey(shuffleId) &&
shuffleMapperAttempts.get(shuffleId)(
+ mapId) != -1) {
+ logWarning(s"[handleRevive] Mapper ended, mapId $mapId, current
attemptId $attemptId, " +
+ s"ended attemptId ${shuffleMapperAttempts.get(shuffleId)(mapId)},
shuffleId $shuffleId.")
+ context.reply(ChangeLocationResponse(StatusCode.MAP_ENDED, None))
+ return
+ }
}
logWarning(s"Do Revive for shuffle ${Utils.makeShuffleKey(applicationId,
shuffleId)}, " +
@@ -621,28 +627,39 @@ class LifecycleManager(appId: String, val conf:
CelebornConf) extends RpcEndpoin
shuffleId: Int): Unit = {
var timeout = stageEndTimeout
val delta = 100
- while (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
- Thread.sleep(delta)
- if (timeout <= 0) {
- logError(s"[handleGetReducerFileGroup] Wait for handleStageEnd
Timeout! $shuffleId.")
- context.reply(
- GetReducerFileGroupResponse(StatusCode.STAGE_END_TIME_OUT,
Array.empty, Array.empty))
- return
+ // reduce partition need wait stage end. While map partition Would commit
every partition synchronously.
+ if (getPartitionType(shuffleId) == PartitionType.REDUCE) {
+ while (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
+ Thread.sleep(delta)
+ if (timeout <= 0) {
+ logError(s"[handleGetReducerFileGroup] Wait for handleStageEnd
Timeout! $shuffleId.")
+ context.reply(
+ GetReducerFileGroupResponse(
+ StatusCode.STAGE_END_TIME_OUT,
+ new ConcurrentHashMap(),
+ Array.empty))
+ return
+ }
+ timeout = timeout - delta
}
- timeout = timeout - delta
+ logDebug("[handleGetReducerFileGroup] Wait for handleStageEnd complete
cost" +
+ s" ${stageEndTimeout - timeout}ms")
}
- logDebug("[handleGetReducerFileGroup] Wait for handleStageEnd complete
cost" +
- s" ${stageEndTimeout - timeout}ms")
if (commitManager.dataLostShuffleSet.contains(shuffleId)) {
context.reply(
- GetReducerFileGroupResponse(StatusCode.SHUFFLE_DATA_LOST, Array.empty,
Array.empty))
+ GetReducerFileGroupResponse(
+ StatusCode.SHUFFLE_DATA_LOST,
+ new ConcurrentHashMap(),
+ Array.empty))
} else {
if (context.isInstanceOf[LocalNettyRpcCallContext]) {
// This branch is for the UTs
context.reply(GetReducerFileGroupResponse(
StatusCode.SUCCESS,
- reducerFileGroupsMap.getOrDefault(shuffleId, Array.empty),
+ reducerFileGroupsMap.getOrDefault(
+ shuffleId,
+ new ConcurrentHashMap()),
shuffleMapperAttempts.getOrDefault(shuffleId, Array.empty)))
} else {
val cachedMsg = getReducerFileGroupRpcCache.get(
@@ -651,7 +668,9 @@ class LifecycleManager(appId: String, val conf:
CelebornConf) extends RpcEndpoin
override def call(): ByteBuffer = {
val returnedMsg = GetReducerFileGroupResponse(
StatusCode.SUCCESS,
- reducerFileGroupsMap.getOrDefault(shuffleId, Array.empty),
+ reducerFileGroupsMap.getOrDefault(
+ shuffleId,
+ new ConcurrentHashMap()),
shuffleMapperAttempts.getOrDefault(shuffleId, Array.empty))
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
}
@@ -684,24 +703,55 @@ class LifecycleManager(appId: String, val conf:
CelebornConf) extends RpcEndpoin
ReleaseSlots(applicationId, shuffleId, List.empty.asJava,
List.empty.asJava))
}
+ private def handleMapPartitionEnd(
+ context: RpcCallContext,
+ applicationId: String,
+ shuffleId: Int,
+ mapId: Int,
+ attemptId: Int,
+ partitionId: Int): Unit = {
+ def reply(result: Boolean): Unit = {
+ val message =
+ s"to handle MapPartitionEnd for ${Utils.makeMapKey(appId, shuffleId,
mapId, attemptId)}, " +
+ s"$partitionId.";
+ result match {
+ case true => // if already committed by another try
+ logDebug(s"Succeed $message")
+ context.reply(MapperEndResponse(StatusCode.SUCCESS))
+ case false =>
+ logError(s"Failed $message")
+ context.reply(MapperEndResponse(StatusCode.SHUFFLE_DATA_LOST))
+ }
+ }
+
+ val dataCommitSuccess = commitManager.finalPartitionCommit(
+ applicationId,
+ shuffleId,
+ reducerFileGroupsMap.get(shuffleId),
+ partitionId)
+ reply(dataCommitSuccess)
+ }
+
private def handleUnregisterShuffle(
appId: String,
shuffleId: Int): Unit = {
- // if StageEnd has not been handled, trigger StageEnd
- if (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
- logInfo(s"Call StageEnd before Unregister Shuffle $shuffleId.")
- handleStageEnd(appId, shuffleId)
- var timeout = stageEndTimeout
- val delta = 100
- while (!commitManager.stageEndShuffleSet.contains(shuffleId) && timeout
> 0) {
- Thread.sleep(delta)
- timeout = timeout - delta
- }
- if (timeout <= 0) {
- logError(s"StageEnd Timeout! $shuffleId.")
- } else {
- logInfo("[handleUnregisterShuffle] Wait for handleStageEnd complete
cost" +
- s" ${stageEndTimeout - timeout}ms")
+ if (getPartitionType(shuffleId) == PartitionType.REDUCE) {
+ // if StageEnd has not been handled, trigger StageEnd
+ if (!commitManager.stageEndShuffleSet.contains(shuffleId)) {
+ logInfo(s"Call StageEnd before Unregister Shuffle $shuffleId.")
+ handleStageEnd(appId, shuffleId)
+ var timeout = stageEndTimeout
+ val delta = 100
+ while (!commitManager.stageEndShuffleSet.contains(shuffleId) &&
timeout > 0) {
+ Thread.sleep(delta)
+ timeout = timeout - delta
+ }
+ if (timeout <= 0) {
+ logError(s"StageEnd Timeout! $shuffleId.")
+ } else {
+ logInfo("[handleUnregisterShuffle] Wait for handleStageEnd complete
cost" +
+ s" ${stageEndTimeout - timeout}ms")
+ }
}
}
diff --git
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 2fb60718..95470f6c 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -91,6 +91,16 @@ public class DummyShuffleClient extends ShuffleClient {
public void mapperEnd(
String applicationId, int shuffleId, int mapId, int attemptId, int
numMappers) {}
+ @Override
+ public void mapPartitionMapperEnd(
+ String applicationId,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int partitionId)
+ throws IOException {}
+
@Override
public void cleanup(String applicationId, int shuffleId, int mapId, int
attemptId) {}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index 2cf93798..ae4f3816 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -232,6 +232,7 @@ message PbMapperEnd {
int32 mapId = 3;
int32 attemptId = 4;
int32 numMappers = 5;
+ int32 partitionId = 6;
}
message PbMapperEndResponse {
@@ -245,7 +246,10 @@ message PbGetReducerFileGroup {
message PbGetReducerFileGroupResponse {
int32 status = 1;
- repeated PbFileGroup fileGroup = 2;
+ // PartitionId -> Partition FileGroup
+ map<int32, PbFileGroup> fileGroups = 2;
+
+ // only reduce partition mode needs to know valid attempts
repeated int32 attempts = 3;
}
diff --git
a/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
b/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
index 4f29c63a..78550907 100644
---
a/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/meta/PartitionLocationInfo.scala
@@ -46,6 +46,24 @@ class PartitionLocationInfo extends Logging {
slavePartitionLocations.containsKey(shuffleKey)
}
+ def containsRelatedShuffleOrPartition(shuffleKey: String, partitionIdOpt:
Option[Int]): Boolean =
+ this.synchronized {
+ partitionIdOpt match {
+ case Some(partitionId) =>
+ containsPartition(shuffleKey, partitionId)
+ case None => containsShuffle(shuffleKey)
+ }
+ }
+
+ private def containsPartition(shuffleKey: String, partitionId: Int): Boolean
= this
+ .synchronized {
+ val contain = masterPartitionLocations.containsKey(
+ shuffleKey) &&
masterPartitionLocations.get(shuffleKey).containsKey(partitionId)
+ contain || (slavePartitionLocations.containsKey(shuffleKey) &&
slavePartitionLocations.get(
+ shuffleKey)
+ .containsKey(partitionId))
+ }
+
def addMasterPartition(shuffleKey: String, location: PartitionLocation): Int
= {
addPartition(shuffleKey, location, masterPartitionLocations)
}
@@ -75,26 +93,44 @@ class PartitionLocationInfo extends Logging {
}
def getAllMasterLocations(shuffleKey: String): util.List[PartitionLocation]
= this.synchronized {
- if (masterPartitionLocations.containsKey(shuffleKey)) {
- masterPartitionLocations.get(shuffleKey)
- .values()
- .asScala
- .flatMap(_.asScala)
- .toList
- .asJava
- } else {
- new util.ArrayList[PartitionLocation]()
- }
+ getMasterLocations(shuffleKey)
}
def getAllSlaveLocations(shuffleKey: String): util.List[PartitionLocation] =
this.synchronized {
- if (slavePartitionLocations.containsKey(shuffleKey)) {
- slavePartitionLocations.get(shuffleKey)
- .values()
- .asScala
- .flatMap(_.asScala)
- .toList
- .asJava
+ getSlaveLocations(shuffleKey)
+ }
+
+ def getMasterLocations(
+ shuffleKey: String,
+ partitionIdOpt: Option[Int] = None): util.List[PartitionLocation] = {
+ getLocations(shuffleKey, masterPartitionLocations, partitionIdOpt)
+ }
+
+ def getSlaveLocations(
+ shuffleKey: String,
+ partitionIdOpt: Option[Int] = None): util.List[PartitionLocation] = {
+ getLocations(shuffleKey, slavePartitionLocations, partitionIdOpt)
+ }
+
+ private def getLocations(
+ shuffleKey: String,
+ partitionInfo: PartitionInfo,
+ partitionIdOpt: Option[Int] = None): util.List[PartitionLocation] =
this.synchronized {
+ if (partitionInfo.containsKey(shuffleKey)) {
+ partitionIdOpt match {
+ case Some(partitionId) => partitionInfo.get(shuffleKey)
+ .values()
+ .asScala
+ .flatMap(_.asScala)
+ .filter(_.getId == partitionId)
+ .toList.asJava
+ case None =>
+ partitionInfo.get(shuffleKey)
+ .values()
+ .asScala
+ .flatMap(_.asScala)
+ .toList.asJava
+ }
} else {
new util.ArrayList[PartitionLocation]()
}
@@ -201,6 +237,24 @@ class PartitionLocationInfo extends Logging {
}
}
+ def removeAllRelatedPartitions(
+ shuffleKey: String,
+ partitionIdOpt: Option[Int]): Unit = this
+ .synchronized {
+ partitionIdOpt match {
+ case Some(partitionId) =>
+ if (masterPartitionLocations.containsKey(shuffleKey)) {
+ masterPartitionLocations.get(shuffleKey).remove(partitionId)
+ }
+ if (slavePartitionLocations.containsKey(shuffleKey)) {
+ slavePartitionLocations.get(shuffleKey).remove(partitionId)
+ }
+ case None =>
+ removeMasterPartitions(shuffleKey)
+ removeSlavePartitions(shuffleKey)
+ }
+ }
+
/**
* @param shuffleKey
* @param uniqueIds
diff --git
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 9697aabb..d8255005 100644
---
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -242,7 +242,8 @@ object ControlMessages extends Logging {
shuffleId: Int,
mapId: Int,
attemptId: Int,
- numMappers: Int)
+ numMappers: Int,
+ partitionId: Int)
extends MasterMessage
case class MapperEndResponse(status: StatusCode) extends MasterMessage
@@ -253,7 +254,7 @@ object ControlMessages extends Logging {
// Path can't be serialized
case class GetReducerFileGroupResponse(
status: StatusCode,
- fileGroup: Array[Array[PartitionLocation]],
+ fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
attempts: Array[Int])
extends MasterMessage
@@ -521,13 +522,14 @@ object ControlMessages extends Logging {
case pb: PbChangeLocationResponse =>
new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE,
pb.toByteArray)
- case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers) =>
+ case MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers,
partitionId) =>
val payload = PbMapperEnd.newBuilder()
.setApplicationId(applicationId)
.setShuffleId(shuffleId)
.setMapId(mapId)
.setAttemptId(attemptId)
.setNumMappers(numMappers)
+ .setPartitionId(partitionId)
.build().toByteArray
new TransportMessage(MessageType.MAPPER_END, payload)
@@ -547,13 +549,13 @@ object ControlMessages extends Logging {
val builder = PbGetReducerFileGroupResponse
.newBuilder()
.setStatus(status.getValue)
- builder.addAllFileGroup(
- fileGroup.map { arr =>
- PbFileGroup.newBuilder().addAllLocations(arr
- .map(PbSerDeUtils.toPbPartitionLocation).toIterable.asJava).build()
- }
- .toIterable
- .asJava)
+ builder.putAllFileGroups(
+ fileGroup.asScala.map { case (partitionId, fileGroup) =>
+ (
+ partitionId,
+
PbFileGroup.newBuilder().addAllLocations(fileGroup.asScala.map(PbSerDeUtils
+ .toPbPartitionLocation).toList.asJava).build())
+ }.asJava)
builder.addAllAttempts(attempts.map(new Integer(_)).toIterable.asJava)
val payload = builder.build().toByteArray
new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE,
payload)
@@ -875,7 +877,8 @@ object ControlMessages extends Logging {
pbMapperEnd.getShuffleId,
pbMapperEnd.getMapId,
pbMapperEnd.getAttemptId,
- pbMapperEnd.getNumMappers)
+ pbMapperEnd.getNumMappers,
+ pbMapperEnd.getPartitionId)
case MAPPER_END_RESPONSE =>
val pbMapperEndResponse =
PbMapperEndResponse.parseFrom(message.getPayload)
@@ -890,9 +893,14 @@ object ControlMessages extends Logging {
case GET_REDUCER_FILE_GROUP_RESPONSE =>
val pbGetReducerFileGroupResponse = PbGetReducerFileGroupResponse
.parseFrom(message.getPayload)
- val fileGroup =
pbGetReducerFileGroupResponse.getFileGroupList.asScala.map { fg =>
-
fg.getLocationsList.asScala.map(PbSerDeUtils.fromPbPartitionLocation).toArray
- }.toArray
+ val fileGroup =
pbGetReducerFileGroupResponse.getFileGroupsMap.asScala.map {
+ case (partitionId, fileGroup) =>
+ (
+ partitionId,
+ fileGroup.getLocationsList.asScala.map(
+ PbSerDeUtils.fromPbPartitionLocation).toSet.asJava)
+ }.asJava
+
val attempts =
pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray
GetReducerFileGroupResponse(
Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
diff --git
a/common/src/main/scala/org/apache/celeborn/common/util/FunctionConverter.scala
b/common/src/main/scala/org/apache/celeborn/common/util/FunctionConverter.scala
new file mode 100644
index 00000000..71c75cf2
--- /dev/null
+++
b/common/src/main/scala/org/apache/celeborn/common/util/FunctionConverter.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util
+
+/**
+ * Implicit conversion from a Scala (2.11) function to a Java function
+ */
+object FunctionConverter {
+
+ implicit def scalaFunctionToJava[From, To](function: (From) => To)
+ : java.util.function.Function[From, To] = {
+ new java.util.function.Function[From, To] {
+ override def apply(input: From): To = function(input)
+ }
+ }
+
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
index d631d7dc..70ecc84b 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
@@ -17,7 +17,12 @@
package org.apache.celeborn.common.util
+import java.util
+
import org.apache.celeborn.RssFunSuite
+import org.apache.celeborn.common.protocol.PartitionLocation
+import
org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse,
MapperEnd}
+import org.apache.celeborn.common.protocol.message.StatusCode
class UtilsSuite extends RssFunSuite {
@@ -92,4 +97,45 @@ class UtilsSuite extends RssFunSuite {
test("getThreadDump") {
assert(Utils.getThreadDump().nonEmpty)
}
+
+ test("MapperEnd class convert with pb") {
+ val mapperEnd = MapperEnd("application1", 1, 1, 1, 2, 1)
+ val mapperEndTrans =
+
Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
+ assert(mapperEnd == mapperEndTrans)
+ }
+
+ test("GetReducerFileGroupResponse class convert with pb") {
+ val fileGroup = new util.HashMap[Integer, util.Set[PartitionLocation]]
+ fileGroup.put(0, partitionLocation(0))
+ fileGroup.put(1, partitionLocation(1))
+ fileGroup.put(2, partitionLocation(2))
+
+ val attempts = Array(0, 0, 1)
+ val response = GetReducerFileGroupResponse(StatusCode.STAGE_ENDED,
fileGroup, attempts)
+ val responseTrans =
Utils.fromTransportMessage(Utils.toTransportMessage(response)).asInstanceOf[
+ GetReducerFileGroupResponse]
+
+ assert(response.status == responseTrans.status)
+ assert(response.attempts.deep == responseTrans.attempts.deep)
+ val set =
+ (response.fileGroup.values().toArray diff
responseTrans.fileGroup.values().toArray).toSet
+ assert(set.size == 0)
+ }
+
+ def partitionLocation(partitionId: Int): util.HashSet[PartitionLocation] = {
+ val partitionSet = new util.HashSet[PartitionLocation]
+ for (i <- 0 until 3) {
+ partitionSet.add(new PartitionLocation(
+ partitionId,
+ i,
+ "host",
+ 100,
+ 1000,
+ 1001,
+ 100,
+ PartitionLocation.Mode.MASTER))
+ }
+ partitionSet
+ }
}
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
index 4e857e17..926210ff 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
@@ -36,6 +36,9 @@ class ShuffleClientSuite extends AnyFunSuite with
MiniClusterFeature
val APP = "app-1"
var shuffleClient: ShuffleClientImpl = _
var lifecycleManager: LifecycleManager = _
+ val numMappers = 8
+ val mapId = 1
+ val attemptId = 0
override def beforeAll(): Unit = {
val masterConf = Map(
@@ -54,11 +57,8 @@ class ShuffleClientSuite extends AnyFunSuite with
MiniClusterFeature
shuffleClient.setupMetaServiceRef(lifecycleManager.self)
}
- test(s"test register map partition task with first attemptId") {
+ test(s"test register map partition task") {
val shuffleId = 1
- val numMappers = 8
- val mapId = 1
- val attemptId = 0
var location =
shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers,
mapId, attemptId)
Assert.assertEquals(location.getId,
PackedPartitionId.packedPartitionId(mapId, attemptId))
@@ -93,6 +93,39 @@ class ShuffleClientSuite extends AnyFunSuite with
MiniClusterFeature
Assert.assertEquals(count, numMappers + 1)
}
+ test(s"test map end & get reducer file group") {
+ val shuffleId = 2
+ shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId,
attemptId)
+ shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId +
1, attemptId)
+ shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId +
2, attemptId)
+ shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId,
attemptId + 1)
+ shuffleClient.mapPartitionMapperEnd(APP, shuffleId, numMappers, mapId,
attemptId, mapId)
+ // retry
+ shuffleClient.mapPartitionMapperEnd(APP, shuffleId, numMappers, mapId,
attemptId, mapId)
+ // another attempt
+ shuffleClient.mapPartitionMapperEnd(
+ APP,
+ shuffleId,
+ numMappers,
+ mapId,
+ attemptId + 1,
+ PackedPartitionId
+ .packedPartitionId(mapId, attemptId + 1))
+ // another mapper
+ shuffleClient.mapPartitionMapperEnd(APP, shuffleId, numMappers, mapId + 1,
attemptId, mapId + 1)
+
+ // reduce file group size (for empty partitions)
+ Assert.assertEquals(shuffleClient.getReduceFileGroupsMap.size(), 0)
+
+ // reduce normal empty RssInputStream
+ var stream = shuffleClient.readPartition(APP, shuffleId, 1, 1)
+ Assert.assertEquals(stream.read(), -1)
+
+ // reduce normal null partition for RssInputStream
+ stream = shuffleClient.readPartition(APP, shuffleId, 3, 1)
+ Assert.assertEquals(stream.read(), -1)
+ }
+
override def afterAll(): Unit = {
// TODO refactor MiniCluster later
println("test done")
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
index 932998e7..97c333a3 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
@@ -50,8 +50,16 @@ class HugeDataTest extends AnyFunSuite
test("celeborn spark integration test - huge data") {
val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[4]")
val ss = SparkSession.builder().config(updateSparkConf(sparkConf,
false)).getOrCreate()
- ss.sparkContext.parallelize(1 to 10000, 2)
- .map { i => (i, Range(1, 10000).mkString(",")) }.groupByKey(16).collect()
+ val value = Range(1, 10000).mkString(",")
+ val tuples = ss.sparkContext.parallelize(1 to 10000, 2)
+ .map { i => (i, value) }.groupByKey(16).collect()
+
+ // verify result
+ assert(tuples.length == 10000)
+ for (elem <- tuples) {
+ assert(elem._2.mkString(",").equals(value))
+ }
+
ss.stop()
}
}