This is an automated email from the ASF dual-hosted git repository.

feiwang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new a7e638706 [CELEBORN-2004] Filter empty partition before 
createIntputStream
a7e638706 is described below

commit a7e638706b697101d4f039ffd789c8a0eeee6889
Author: Fei Wang <[email protected]>
AuthorDate: Tue May 20 19:37:11 2025 -0700

    [CELEBORN-2004] Filter empty partition before createIntputStream
    
    ### What changes were proposed in this pull request?
    Filter empty partitions from partitionFileGroup before createInputStream.
    
    ### Why are the changes needed?
    Avoid creating an InputStream for empty partitions, of which there might be a 
lot when the partition number is large and the data is small.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    UT and cluster test.
    
    Closes #3266 from zaynt4606/clb2004.
    
    Lead-authored-by: Fei Wang <[email protected]>
    Co-authored-by: zhengtao <[email protected]>
    Signed-off-by: Wang, Fei <[email protected]>
---
 .../shuffle/celeborn/CelebornShuffleReader.scala   | 94 ++++++++++++----------
 1 file changed, 51 insertions(+), 43 deletions(-)

diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 958e08196..19c703319 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -176,7 +176,47 @@ class CelebornShuffleReader[K, C](
     val splitSkewPartitionWithoutMapRange =
       ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, 
endMapIndex)
 
-    (startPartition until endPartition).foreach { partitionId =>
+    // filter empty partition
+    val partitionIdList = List.range(startPartition, endPartition).filter(p =>
+      fileGroups.partitionGroups.containsKey(p))
+
+    def makeOpenStreamList(locations: JSet[PartitionLocation]): Unit = {
+      locations.asScala.foreach { location =>
+        partCnt += 1
+        val hostPort = location.hostAndFetchPort
+        if (!workerRequestMap.containsKey(hostPort)) {
+          try {
+            val client = shuffleClient.getDataClientFactory().createClient(
+              location.getHost,
+              location.getFetchPort)
+            val pbOpenStreamList = PbOpenStreamList.newBuilder()
+            pbOpenStreamList.setShuffleKey(shuffleKey)
+            workerRequestMap.put(
+              hostPort,
+              (client, new JArrayList[PartitionLocation], pbOpenStreamList))
+          } catch {
+            case ex: Exception =>
+              shuffleClient.excludeFailedFetchLocation(hostPort, ex)
+              logWarning(
+                s"Failed to create client for $shuffleKey-${location.getId} 
from host: ${hostPort}. " +
+                  s"Shuffle reader will try its replica if exists.")
+          }
+        }
+        workerRequestMap.get(hostPort) match {
+          case (_, locArr, pbOpenStreamListBuilder) =>
+            locArr.add(location)
+            pbOpenStreamListBuilder.addFileName(location.getFileName)
+              .addStartIndex(startMapIndex)
+              .addEndIndex(endMapIndex)
+            pbOpenStreamListBuilder.addReadLocalShuffle(
+              localFetchEnabled && location.getHost.equals(localHostAddress))
+          case _ =>
+            logDebug(s"Empty client for host ${hostPort}")
+        }
+      }
+    }
+
+    partitionIdList.foreach { partitionId =>
       if (fileGroups.partitionGroups.containsKey(partitionId)) {
         var locations = fileGroups.partitionGroups.get(partitionId)
         if (splitSkewPartitionWithoutMapRange) {
@@ -194,40 +234,7 @@ class CelebornShuffleReader[K, C](
           locations = filterLocations.asJava
           partitionId2PartitionLocations.put(partitionId, locations)
         }
-
-        locations.asScala.foreach { location =>
-          partCnt += 1
-          val hostPort = location.hostAndFetchPort
-          if (!workerRequestMap.containsKey(hostPort)) {
-            try {
-              val client = shuffleClient.getDataClientFactory().createClient(
-                location.getHost,
-                location.getFetchPort)
-              val pbOpenStreamList = PbOpenStreamList.newBuilder()
-              pbOpenStreamList.setShuffleKey(shuffleKey)
-              workerRequestMap.put(
-                hostPort,
-                (client, new JArrayList[PartitionLocation], pbOpenStreamList))
-            } catch {
-              case ex: Exception =>
-                
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
-                logWarning(
-                  s"Failed to create client for $shuffleKey-$partitionId from 
host: ${location.hostAndFetchPort}. " +
-                    s"Shuffle reader will try its replica if exists.")
-            }
-          }
-          workerRequestMap.get(hostPort) match {
-            case (_, locArr, pbOpenStreamListBuilder) =>
-              locArr.add(location)
-              pbOpenStreamListBuilder.addFileName(location.getFileName)
-                .addStartIndex(startMapIndex)
-                .addEndIndex(endMapIndex)
-              pbOpenStreamListBuilder.addReadLocalShuffle(
-                localFetchEnabled && location.getHost.equals(localHostAddress))
-            case _ =>
-              logDebug(s"Empty client for host ${hostPort}")
-          }
-        }
+        makeOpenStreamList(locations)
       }
     }
 
@@ -321,17 +328,16 @@ class CelebornShuffleReader[K, C](
     }
 
     val inputStreamCreationWindow = conf.clientInputStreamCreationWindow
-    (startPartition until Math.min(
-      startPartition + inputStreamCreationWindow,
-      endPartition)).foreach(partitionId => {
+    (0 until Math.min(inputStreamCreationWindow, 
partitionIdList.size)).foreach(listIndex => {
       streamCreatorPool.submit(new Runnable {
         override def run(): Unit = {
-          createInputStream(partitionId)
+          createInputStream(partitionIdList(listIndex))
         }
       })
     })
 
-    val recordIter = (startPartition until 
endPartition).iterator.map(partitionId => {
+    var curIndex = 0
+    val recordIter = partitionIdList.iterator.map(partitionId => {
       if (handle.numMappers > 0) {
         val startFetchWait = System.nanoTime()
         var inputStream: CelebornInputStream = streams.get(partitionId)
@@ -361,16 +367,18 @@ class CelebornShuffleReader[K, C](
         context.addTaskCompletionListener[Unit](_ => inputStream.close())
 
         // Advance the input creation window
-        if (partitionId + inputStreamCreationWindow < endPartition) {
+        if (curIndex + inputStreamCreationWindow < partitionIdList.size) {
+          val nextPartitionId = partitionIdList(curIndex + 
inputStreamCreationWindow)
           streamCreatorPool.submit(new Runnable {
             override def run(): Unit = {
-              createInputStream(partitionId + inputStreamCreationWindow)
+              createInputStream(nextPartitionId)
             }
           })
         }
-
+        curIndex = curIndex + 1
         (partitionId, inputStream)
       } else {
+        curIndex = curIndex + 1
         (partitionId, CelebornInputStream.empty())
       }
     }).filter {

Reply via email to