This is an automated email from the ASF dual-hosted git repository.

kerwinzhang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ce8d6fd9 [CELEBORN-1102] Optimize the performance of getAllPrimaryLocationsWithMinEpoch
2ce8d6fd9 is described below

commit 2ce8d6fd95a7755462de2186b1f799ab0ac08427
Author: xiyu.zk <[email protected]>
AuthorDate: Tue Oct 31 20:37:17 2023 +0800

    [CELEBORN-1102] Optimize the performance of getAllPrimaryLocationsWithMinEpoch
    
    ### What changes were proposed in this pull request?
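    Optimize the performance of `getAllPrimaryLocationsWithMinEpoch` by replacing the
    Scala collection pipeline (`asScala.map(_.asScala.min).toSet.asJava`) with a
    single-pass `while` loop that collects the minimum-epoch location of each
    partition into an `ArrayBuffer`. The method now returns
    `ArrayBuffer[PartitionLocation]` instead of `util.Set[PartitionLocation]`, so the
    caller in `LifecycleManager` no longer needs `.asScala`.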
    
    ### Why are the changes needed?
    #### Before optimization:
    
![image](https://github.com/apache/incubator-celeborn/assets/107825064/0ccbf503-99b7-45db-a8bd-6539e854d011)
    
    #### After optimization:
    
![image](https://github.com/apache/incubator-celeborn/assets/107825064/0cb54276-a089-44dc-9b75-6649537515f2)
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Closes #2058 from kerwin-zk/issue-1102.
    
    Authored-by: xiyu.zk <[email protected]>
    Signed-off-by: xiyu.zk <[email protected]>
---
 .../apache/celeborn/client/LifecycleManager.scala  |  2 +-
 .../common/meta/ShufflePartitionLocationInfo.scala | 24 +++++++++++++++++-----
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index ea1517a21..bd8527db3 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -319,7 +319,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
           val initialLocs = workerSnapshots(shuffleId)
             .values()
             .asScala
-            .flatMap(_.getAllPrimaryLocationsWithMinEpoch().asScala)
+            .flatMap(_.getAllPrimaryLocationsWithMinEpoch())
             .filter(p =>
               (partitionType == PartitionType.REDUCE && p.getEpoch == 0) || (partitionType == PartitionType.MAP && p.getId == partitionId))
             .toArray
diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala
index 1cf318a58..12e848830 100644
--- a/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala
@@ -21,6 +21,7 @@ import java.util
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.celeborn.common.protocol.PartitionLocation
 
@@ -29,7 +30,6 @@ class ShufflePartitionLocationInfo {
 
   private val primaryPartitionLocations = new PartitionInfo
   private val replicaPartitionLocations = new PartitionInfo
-  implicit val partitionOrdering: Ordering[PartitionLocation] = Ordering.by(_.getEpoch)
 
   def addPrimaryPartitions(primaryLocations: util.List[PartitionLocation]) = {
     addPartitions(primaryPartitionLocations, primaryLocations)
@@ -89,10 +89,24 @@ class ShufflePartitionLocationInfo {
     }
   }
 
-  def getAllPrimaryLocationsWithMinEpoch(): util.Set[PartitionLocation] = {
-    primaryPartitionLocations.values().asScala.map { partitionLocations =>
-      partitionLocations.asScala.min
-    }.toSet.asJava
+  def getAllPrimaryLocationsWithMinEpoch(): ArrayBuffer[PartitionLocation] = {
+    val set = new ArrayBuffer[PartitionLocation](primaryPartitionLocations.size())
+    val locationsIterator = primaryPartitionLocations.values().iterator()
+    while (locationsIterator.hasNext) {
+      val locationIterator = locationsIterator.next().iterator()
+      var min: PartitionLocation = null
+      if (locationIterator.hasNext) {
+        min = locationIterator.next();
+      }
+      while (locationIterator.hasNext) {
+        val next = locationIterator.next()
+        if (min.getEpoch > next.getEpoch) {
+          min = next;
+        }
+      }
+      set += min;
+    }
+    set
   }
 
   private def addPartitions(

Reply via email to