This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.3 by this push:
     new 6a8f45b8c [CELEBORN-756][WORKER] Refactor `PushDataHandler` class to utilize `while` loop
6a8f45b8c is described below

commit 6a8f45b8c97a73edb7c620df878f8ba5e6676227
Author: Fu Chen <[email protected]>
AuthorDate: Fri Jun 30 18:13:53 2023 +0800

    [CELEBORN-756][WORKER] Refactor `PushDataHandler` class to utilize `while` loop
    
    ### What changes were proposed in this pull request?
    
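    As the title says: replace the collection combinators (`map`, `foreach`, `find`) in `PushDataHandler` with `while` loops.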
    
    ### Why are the changes needed?
    
    Per https://github.com/databricks/scala-style-guide#traversal-and-zipwithindex, use `while` loops for performance-sensitive code.
    
    worker's flame graph before:
    
    ![截屏2023-06-30 下午5 58 02](https://github.com/apache/incubator-celeborn/assets/8537877/28c199b6-a29b-4501-8064-e0f2ddb2a8b9)
    
    after:
    
    ![截屏2023-06-30 下午5 58 18](https://github.com/apache/incubator-celeborn/assets/8537877/c6134959-5f78-436b-aa29-a78882b09e84)
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Pass GA
    
    Closes #1668 from cfmcgrady/while-loop-2.
    
    Authored-by: Fu Chen <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
    (cherry picked from commit 047e90b17bdae4405bdd12f254646b1d5392aaee)
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../common/meta/WorkerPartitionLocationInfo.scala  | 26 ++++++++++++
 .../service/deploy/worker/PushDataHandler.scala    | 46 ++++++++++++++++------
 2 files changed, 60 insertions(+), 12 deletions(-)

diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
index 7dcaaf992..306c4c1e7 100644
--- a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
@@ -62,10 +62,36 @@ class WorkerPartitionLocationInfo extends Logging {
     getLocation(shuffleKey, uniqueId, primaryPartitionLocations)
   }
 
+  def getPrimaryLocations(
+      shuffleKey: String,
+      uniqueIds: Array[String]): Array[(String, PartitionLocation)] = {
+    val locations = new Array[(String, PartitionLocation)](uniqueIds.length)
+    var i = 0
+    while (i < uniqueIds.length) {
+      val uniqueId = uniqueIds(i)
+      locations(i) = uniqueId -> getPrimaryLocation(shuffleKey, uniqueId)
+      i += 1
+    }
+    locations
+  }
+
   def getReplicaLocation(shuffleKey: String, uniqueId: String): 
PartitionLocation = {
     getLocation(shuffleKey, uniqueId, replicaPartitionLocations)
   }
 
+  def getReplicaLocations(
+      shuffleKey: String,
+      uniqueIds: Array[String]): Array[(String, PartitionLocation)] = {
+    val locations = new Array[(String, PartitionLocation)](uniqueIds.length)
+    var i = 0
+    while (i < uniqueIds.length) {
+      val uniqueId = uniqueIds(i)
+      locations(i) = uniqueId -> getReplicaLocation(shuffleKey, uniqueId)
+      i += 1
+    }
+    locations
+  }
+
   def removeShuffle(shuffleKey: String): Unit = {
     primaryPartitionLocations.remove(shuffleKey)
     replicaPartitionLocations.remove(shuffleKey)
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 627104e9a..12783265f 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -417,20 +417,21 @@ class PushDataHandler extends BaseMessageHandler with Logging {
       return
     }
 
-    val partitionIdToLocations = pushMergedData.partitionUniqueIds.map { id =>
+    val partitionIdToLocations =
       if (isPrimary) {
-        id -> partitionLocationInfo.getPrimaryLocation(shuffleKey, id)
+        partitionLocationInfo.getPrimaryLocations(shuffleKey, 
pushMergedData.partitionUniqueIds)
       } else {
-        id -> partitionLocationInfo.getReplicaLocation(shuffleKey, id)
+        partitionLocationInfo.getReplicaLocations(shuffleKey, 
pushMergedData.partitionUniqueIds)
       }
-    }
 
     // Fetch real batchId from body will add more cost and no meaning for 
replicate.
     val doReplicate =
       partitionIdToLocations.head._2 != null && 
partitionIdToLocations.head._2.hasPeer && isPrimary
 
     // find FileWriters responsible for the data
-    partitionIdToLocations.foreach { case (id, loc) =>
+    var index = 0
+    while (index < partitionIdToLocations.length) {
+      val (id, loc) = partitionIdToLocations(index)
       if (loc == null) {
         val (mapId, attemptId) = getMapAttempt(body)
         // MapperAttempts for a shuffle exists after any CommitFiles request 
succeeds.
@@ -468,6 +469,7 @@ class PushDataHandler extends BaseMessageHandler with Logging {
         }
         return
       }
+      index += 1
     }
 
     // During worker shutdown, worker will return HARD_SPLIT for all existed 
partition.
@@ -477,11 +479,9 @@ class PushDataHandler extends BaseMessageHandler with Logging {
       return
     }
 
-    val fileWriters =
-      
partitionIdToLocations.map(_._2).map(_.asInstanceOf[WorkingPartition].getFileWriter)
-    val fileWriterWithException = fileWriters.find(_.getException != null)
-    if (fileWriterWithException.nonEmpty) {
-      val exception = fileWriterWithException.get.getException
+    val (fileWriters, exceptionFileWriterIndexOpt) = 
getFileWriters(partitionIdToLocations)
+    if (exceptionFileWriterIndexOpt.isDefined) {
+      val fileWriterWithException = 
fileWriters(exceptionFileWriterIndexOpt.get)
       val cause =
         if (isPrimary) {
           StatusCode.PUSH_DATA_WRITE_FAIL_PRIMARY
@@ -490,7 +490,7 @@ class PushDataHandler extends BaseMessageHandler with Logging {
         }
       logError(
         s"While handling PushMergedData, throw $cause, fileWriter 
$fileWriterWithException has exception.",
-        exception)
+        fileWriterWithException.getException)
       workerSource.incCounter(WorkerSource.WriteDataFailCount)
       callbackWithTimer.onFailure(new CelebornIOException(cause))
       return
@@ -617,7 +617,7 @@ class PushDataHandler extends BaseMessageHandler with Logging {
       }
     }
 
-    var index = 0
+    index = 0
     var fileWriter: FileWriter = null
     var alreadyClosed = false
     while (index < fileWriters.length) {
@@ -656,6 +656,28 @@ class PushDataHandler extends BaseMessageHandler with Logging {
     }
   }
 
+  /**
+   * returns an array of FileWriters from partition locations along with an 
optional index for any FileWriter that
+   * encountered an exception.
+   */
+  private def getFileWriters(
+      partitionIdToLocations: Array[(String, PartitionLocation)])
+      : (Array[FileWriter], Option[Int]) = {
+    val fileWriters = new Array[FileWriter](partitionIdToLocations.length)
+    var i = 0
+    var exceptionFileWriterIndex: Option[Int] = None
+    while (i < partitionIdToLocations.length) {
+      val (_, workingPartition) = partitionIdToLocations(i)
+      val fileWriter = 
workingPartition.asInstanceOf[WorkingPartition].getFileWriter
+      if (fileWriter.getException != null) {
+        exceptionFileWriterIndex = Some(i)
+      }
+      fileWriters(i) = fileWriter
+      i += 1
+    }
+    (fileWriters, exceptionFileWriterIndex)
+  }
+
   private def getMapAttempt(body: ByteBuf): (Int, Int) = {
     // header: mapId attemptId batchId compressedTotalSize
     val header = new Array[Byte](8)

Reply via email to