This is an automated email from the ASF dual-hosted git repository.
chengpan pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/branch-0.3 by this push:
new 6a8f45b8c [CELEBORN-756][WORKER] Refactor `PushDataHandler` class to
utilize `while` loop
6a8f45b8c is described below
commit 6a8f45b8c97a73edb7c620df878f8ba5e6676227
Author: Fu Chen <[email protected]>
AuthorDate: Fri Jun 30 18:13:53 2023 +0800
[CELEBORN-756][WORKER] Refactor `PushDataHandler` class to utilize `while`
loop
### What changes were proposed in this pull request?
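As the title says: refactor `PushDataHandler` so that handling of PushMergedData
uses `while` loops instead of the collection combinators `map`, `foreach` and
`find`. To support this, add the batch helpers `getPrimaryLocations` and
`getReplicaLocations` to `WorkerPartitionLocationInfo`, and a private
`getFileWriters` helper to `PushDataHandler`.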
### Why are the changes needed?
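Per the Databricks Scala style guide
(https://github.com/databricks/scala-style-guide#traversal-and-zipwithindex),
performance-sensitive code should use `while` loops rather than `for` loops or
functional transformations such as `map` and `foreach`.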
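Worker's flame graph before:
(image attached to the pull request; not included in this plain-text email)
Worker's flame graph after:
(image attached to the pull request; not included in this plain-text email)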
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Passes GA (the GitHub Actions CI checks).
Closes #1668 from cfmcgrady/while-loop-2.
Authored-by: Fu Chen <[email protected]>
Signed-off-by: Cheng Pan <[email protected]>
(cherry picked from commit 047e90b17bdae4405bdd12f254646b1d5392aaee)
Signed-off-by: Cheng Pan <[email protected]>
---
.../common/meta/WorkerPartitionLocationInfo.scala | 26 ++++++++++++
.../service/deploy/worker/PushDataHandler.scala | 46 ++++++++++++++++------
2 files changed, 60 insertions(+), 12 deletions(-)
diff --git
a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
index 7dcaaf992..306c4c1e7 100644
---
a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
@@ -62,10 +62,36 @@ class WorkerPartitionLocationInfo extends Logging {
getLocation(shuffleKey, uniqueId, primaryPartitionLocations)
}
+ def getPrimaryLocations(
+ shuffleKey: String,
+ uniqueIds: Array[String]): Array[(String, PartitionLocation)] = {
+ val locations = new Array[(String, PartitionLocation)](uniqueIds.length)
+ var i = 0
+ while (i < uniqueIds.length) {
+ val uniqueId = uniqueIds(i)
+ locations(i) = uniqueId -> getPrimaryLocation(shuffleKey, uniqueId)
+ i += 1
+ }
+ locations
+ }
+
def getReplicaLocation(shuffleKey: String, uniqueId: String):
PartitionLocation = {
getLocation(shuffleKey, uniqueId, replicaPartitionLocations)
}
+ def getReplicaLocations(
+ shuffleKey: String,
+ uniqueIds: Array[String]): Array[(String, PartitionLocation)] = {
+ val locations = new Array[(String, PartitionLocation)](uniqueIds.length)
+ var i = 0
+ while (i < uniqueIds.length) {
+ val uniqueId = uniqueIds(i)
+ locations(i) = uniqueId -> getReplicaLocation(shuffleKey, uniqueId)
+ i += 1
+ }
+ locations
+ }
+
def removeShuffle(shuffleKey: String): Unit = {
primaryPartitionLocations.remove(shuffleKey)
replicaPartitionLocations.remove(shuffleKey)
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 627104e9a..12783265f 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -417,20 +417,21 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
return
}
- val partitionIdToLocations = pushMergedData.partitionUniqueIds.map { id =>
+ val partitionIdToLocations =
if (isPrimary) {
- id -> partitionLocationInfo.getPrimaryLocation(shuffleKey, id)
+ partitionLocationInfo.getPrimaryLocations(shuffleKey,
pushMergedData.partitionUniqueIds)
} else {
- id -> partitionLocationInfo.getReplicaLocation(shuffleKey, id)
+ partitionLocationInfo.getReplicaLocations(shuffleKey,
pushMergedData.partitionUniqueIds)
}
- }
// Fetch real batchId from body will add more cost and no meaning for
replicate.
val doReplicate =
partitionIdToLocations.head._2 != null &&
partitionIdToLocations.head._2.hasPeer && isPrimary
// find FileWriters responsible for the data
- partitionIdToLocations.foreach { case (id, loc) =>
+ var index = 0
+ while (index < partitionIdToLocations.length) {
+ val (id, loc) = partitionIdToLocations(index)
if (loc == null) {
val (mapId, attemptId) = getMapAttempt(body)
// MapperAttempts for a shuffle exists after any CommitFiles request
succeeds.
@@ -468,6 +469,7 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
}
return
}
+ index += 1
}
// During worker shutdown, worker will return HARD_SPLIT for all existed
partition.
@@ -477,11 +479,9 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
return
}
- val fileWriters =
-
partitionIdToLocations.map(_._2).map(_.asInstanceOf[WorkingPartition].getFileWriter)
- val fileWriterWithException = fileWriters.find(_.getException != null)
- if (fileWriterWithException.nonEmpty) {
- val exception = fileWriterWithException.get.getException
+ val (fileWriters, exceptionFileWriterIndexOpt) =
getFileWriters(partitionIdToLocations)
+ if (exceptionFileWriterIndexOpt.isDefined) {
+ val fileWriterWithException =
fileWriters(exceptionFileWriterIndexOpt.get)
val cause =
if (isPrimary) {
StatusCode.PUSH_DATA_WRITE_FAIL_PRIMARY
@@ -490,7 +490,7 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
}
logError(
s"While handling PushMergedData, throw $cause, fileWriter
$fileWriterWithException has exception.",
- exception)
+ fileWriterWithException.getException)
workerSource.incCounter(WorkerSource.WriteDataFailCount)
callbackWithTimer.onFailure(new CelebornIOException(cause))
return
@@ -617,7 +617,7 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
}
}
- var index = 0
+ index = 0
var fileWriter: FileWriter = null
var alreadyClosed = false
while (index < fileWriters.length) {
@@ -656,6 +656,28 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
}
}
+ /**
+ * returns an array of FileWriters from partition locations along with an
optional index for any FileWriter that
+ * encountered an exception.
+ */
+ private def getFileWriters(
+ partitionIdToLocations: Array[(String, PartitionLocation)])
+ : (Array[FileWriter], Option[Int]) = {
+ val fileWriters = new Array[FileWriter](partitionIdToLocations.length)
+ var i = 0
+ var exceptionFileWriterIndex: Option[Int] = None
+ while (i < partitionIdToLocations.length) {
+ val (_, workingPartition) = partitionIdToLocations(i)
+ val fileWriter =
workingPartition.asInstanceOf[WorkingPartition].getFileWriter
+ if (fileWriter.getException != null) {
+ exceptionFileWriterIndex = Some(i)
+ }
+ fileWriters(i) = fileWriter
+ i += 1
+ }
+ (fileWriters, exceptionFileWriterIndex)
+ }
+
private def getMapAttempt(body: ByteBuf): (Int, Int) = {
// header: mapId attemptId batchId compressedTotalSize
val header = new Array[Byte](8)