This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new a42ec85a6 [CELEBORN-943][PERF] Pre-create CelebornInputStreams in 
CelebornShuffleReader
a42ec85a6 is described below

commit a42ec85a6e5099ce2d54bcf01d0eb0e63b26df6d
Author: zky.zhoukeyong <[email protected]>
AuthorDate: Mon Sep 4 21:46:11 2023 +0800

    [CELEBORN-943][PERF] Pre-create CelebornInputStreams in 
CelebornShuffleReader
    
    ### What changes were proposed in this pull request?
    This PR fixes performance degradation when Spark's coalescePartitions takes 
effect caused
    by RPC latency.
    
    ### Why are the changes needed?
    I encountered a performance degradation when testing TPC-DS 10T q10:
    ||Time|
    |---|---|
    |ESS|14s|
    |Celeborn| 24s|
    
    After digging into it I found out that q10 triggers partition coalescence:
    
![image](https://github.com/apache/incubator-celeborn/assets/948245/0b4745da-8d57-4661-a35d-683d97f56e1d)
    
    As I configured `spark.sql.adaptive.coalescePartitions.initialPartitionNum` 
to 1000, `CelebornShuffleReader`
    will call `shuffleClient.readPartition` sequentially 1000 times, causing 
the delay.
    
    This PR optimizes by calling `shuffleClient.readPartition` in parallel. 
After this PR q10 time becomes 14s.
    
    ### Does this PR introduce _any_ user-facing change?
    No, but introduced a new client-side configuration
    `celeborn.client.eagerlyCreateInputStream.threads`, which defaults to 32.
    
    ### How was this patch tested?
    Tested with TPC-DS 1T; passes GA.
    
    Closes #1876 from waitinfuture/943.
    
    Lead-authored-by: zky.zhoukeyong <[email protected]>
    Co-authored-by: Keyong Zhou <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../shuffle/celeborn/CelebornShuffleReader.scala   | 66 +++++++++++++++++++---
 .../shuffle/celeborn/CelebornShuffleReader.scala   | 64 +++++++++++++++++++--
 .../celeborn/client/read/DfsPartitionReader.java   |  6 +-
 .../celeborn/client/read/LocalPartitionReader.java | 20 +++++--
 .../org/apache/celeborn/common/CelebornConf.scala  |  9 +++
 docs/configuration/client.md                       |  1 +
 6 files changed, 146 insertions(+), 20 deletions(-)

diff --git 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 8d5c99e0a..337143a50 100644
--- 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -17,9 +17,14 @@
 
 package org.apache.spark.shuffle.celeborn
 
+import java.io.IOException
+import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
+import java.util.concurrent.atomic.AtomicReference
+
 import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.shuffle.ShuffleReader
+import 
org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
@@ -27,6 +32,8 @@ import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.client.read.CelebornInputStream
 import org.apache.celeborn.client.read.MetricsCallback
 import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.exception.CelebornIOException
+import org.apache.celeborn.common.util.ThreadUtils
 
 class CelebornShuffleReader[K, C](
     handle: CelebornShuffleHandle[K, _, C],
@@ -39,13 +46,15 @@ class CelebornShuffleReader[K, C](
   extends ShuffleReader[K, C] with Logging {
 
   private val dep = handle.dependency
-  private val essShuffleClient = ShuffleClient.get(
+  private val shuffleClient = ShuffleClient.get(
     handle.appUniqueId,
     handle.lifecycleManagerHost,
     handle.lifecycleManagerPort,
     conf,
     handle.userIdentifier)
 
+  private val exceptionRef = new AtomicReference[IOException]
+
   override def read(): Iterator[Product2[K, C]] = {
 
     val serializerInstance = dep.serializer.newInstance()
@@ -60,15 +69,54 @@ class CelebornShuffleReader[K, C](
         readMetrics.incFetchWaitTime(time)
     }
 
+    if (streamCreatorPool == null) {
+      CelebornShuffleReader.synchronized {
+        if (streamCreatorPool == null) {
+          streamCreatorPool = ThreadUtils.newDaemonCachedThreadPool(
+            "celeborn-create-stream-thread",
+            conf.readStreamCreatorPoolThreads,
+            60);
+        }
+      }
+    }
+
+    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
+    (startPartition until endPartition).map(partitionId => {
+      streamCreatorPool.submit(new Runnable {
+        override def run(): Unit = {
+          if (exceptionRef.get() == null) {
+            try {
+              val inputStream = shuffleClient.readPartition(
+                handle.shuffleId,
+                partitionId,
+                context.attemptNumber(),
+                startMapIndex,
+                endMapIndex)
+              streams.put(partitionId, inputStream)
+            } catch {
+              case e: IOException =>
+                logInfo("Exception caught when readPartition!")
+                exceptionRef.compareAndSet(null, e)
+              case e: Throwable =>
+                logInfo("Non IOException caught when readPartition!", e)
+                exceptionRef.compareAndSet(null, new CelebornIOException(e))
+            }
+          }
+        }
+      })
+    })
+
     val recordIter = (startPartition until 
endPartition).iterator.map(partitionId => {
       if (handle.numMaps > 0) {
         val start = System.currentTimeMillis()
-        val inputStream = essShuffleClient.readPartition(
-          handle.shuffleId,
-          partitionId,
-          context.attemptNumber(),
-          startMapIndex,
-          endMapIndex)
+        var inputStream: CelebornInputStream = streams.get(partitionId)
+        while (inputStream == null) {
+          if (exceptionRef.get() != null) {
+            throw exceptionRef.get()
+          }
+          Thread.sleep(50)
+          inputStream = streams.get(partitionId)
+        }
         metricsCallback.incReadTime(System.currentTimeMillis() - start)
         inputStream.setCallback(metricsCallback)
         // ensure inputStream is closed when task completes
@@ -135,3 +183,7 @@ class CelebornShuffleReader[K, C](
     }
   }
 }
+
+object CelebornShuffleReader {
+  var streamCreatorPool: ThreadPoolExecutor = null
+}
diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index eadb655c9..f07ed4989 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -17,16 +17,23 @@
 
 package org.apache.spark.shuffle.celeborn
 
+import java.io.IOException
+import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
+import java.util.concurrent.atomic.AtomicReference
+
 import org.apache.spark.{InterruptibleIterator, ShuffleDependency, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter}
+import 
org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
 import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
 import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.exception.CelebornIOException
+import org.apache.celeborn.common.util.ThreadUtils
 
 class CelebornShuffleReader[K, C](
     handle: CelebornShuffleHandle[K, _, C],
@@ -47,6 +54,8 @@ class CelebornShuffleReader[K, C](
     conf,
     handle.userIdentifier)
 
+  private val exceptionRef = new AtomicReference[IOException]
+
   override def read(): Iterator[Product2[K, C]] = {
 
     val serializerInstance = newSerializerInstance(dep)
@@ -62,15 +71,54 @@ class CelebornShuffleReader[K, C](
         metrics.incFetchWaitTime(time)
     }
 
+    if (streamCreatorPool == null) {
+      CelebornShuffleReader.synchronized {
+        if (streamCreatorPool == null) {
+          streamCreatorPool = ThreadUtils.newDaemonCachedThreadPool(
+            "celeborn-create-stream-thread",
+            conf.readStreamCreatorPoolThreads,
+            60);
+        }
+      }
+    }
+
+    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
+    (startPartition until endPartition).map(partitionId => {
+      streamCreatorPool.submit(new Runnable {
+        override def run(): Unit = {
+          if (exceptionRef.get() == null) {
+            try {
+              val inputStream = shuffleClient.readPartition(
+                handle.shuffleId,
+                partitionId,
+                context.attemptNumber(),
+                startMapIndex,
+                endMapIndex)
+              streams.put(partitionId, inputStream)
+            } catch {
+              case e: IOException =>
+                logInfo("Exception caught when readPartition!", e)
+                exceptionRef.compareAndSet(null, e)
+              case e: Throwable =>
+                logInfo("Non IOException caught when readPartition!", e)
+                exceptionRef.compareAndSet(null, new CelebornIOException(e))
+            }
+          }
+        }
+      })
+    })
+
     val recordIter = (startPartition until 
endPartition).iterator.map(partitionId => {
       if (handle.numMappers > 0) {
         val start = System.currentTimeMillis()
-        val inputStream = shuffleClient.readPartition(
-          handle.shuffleId,
-          partitionId,
-          context.attemptNumber(),
-          startMapIndex,
-          endMapIndex)
+        var inputStream: CelebornInputStream = streams.get(partitionId)
+        while (inputStream == null) {
+          if (exceptionRef.get() != null) {
+            throw exceptionRef.get()
+          }
+          Thread.sleep(50)
+          inputStream = streams.get(partitionId)
+        }
         metricsCallback.incReadTime(System.currentTimeMillis() - start)
         inputStream.setCallback(metricsCallback)
         // ensure inputStream is closed when task completes
@@ -148,3 +196,7 @@ class CelebornShuffleReader[K, C](
   }
 
 }
+
+object CelebornShuffleReader {
+  var streamCreatorPool: ThreadPoolExecutor = null
+}
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java 
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
index 68f6308b3..ec930b8d3 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
@@ -53,6 +53,7 @@ public class DfsPartitionReader implements PartitionReader {
   private final AtomicReference<IOException> exception = new 
AtomicReference<>();
   private volatile boolean closed = false;
   private Thread fetchThread;
+  private boolean fetchThreadStarted;
   private FSDataInputStream hdfsInputStream;
   private int numChunks = 0;
   private int returnedChunks = 0;
@@ -168,7 +169,6 @@ public class DfsPartitionReader implements PartitionReader {
               logger.error("thread {} failed with exception {}", t, e);
             }
           });
-      fetchThread.start();
       logger.debug("Start dfs read on location {}", location);
       ShuffleClient.incrementTotalReadCounter();
     }
@@ -218,6 +218,10 @@ public class DfsPartitionReader implements PartitionReader 
{
   @Override
   public ByteBuf next() throws IOException, InterruptedException {
     ByteBuf chunk = null;
+    if (!fetchThreadStarted) {
+      fetchThreadStarted = true;
+      fetchThread.start();
+    }
     try {
       while (chunk == null) {
         checkException();
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
 
b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
index 1168c8d2c..486f515c5 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
@@ -58,7 +58,9 @@ public class LocalPartitionReader implements PartitionReader {
   private final int numChunks;
   private int returnedChunks = 0;
   private int chunkIndex = 0;
-  private final FileChannel shuffleChannel;
+  private String fullPath;
+  private boolean mapRangeRead = false;
+  private FileChannel shuffleChannel;
   private List<Long> chunkOffsets;
   private AtomicBoolean pendingFetchTask = new AtomicBoolean(false);
 
@@ -111,10 +113,8 @@ public class LocalPartitionReader implements 
PartitionReader {
 
     chunkOffsets = new ArrayList<>(streamHandle.getChunkOffsetsList());
     numChunks = streamHandle.getNumChunks();
-    shuffleChannel = 
FileChannelUtils.openReadableFileChannel(streamHandle.getFullPath());
-    if (endMapIndex != Integer.MAX_VALUE) {
-      shuffleChannel.position(chunkOffsets.get(0));
-    }
+    fullPath = streamHandle.getFullPath();
+    mapRangeRead = endMapIndex != Integer.MAX_VALUE;
 
     logger.debug(
         "Local partition reader {} offsets:{}",
@@ -126,6 +126,12 @@ public class LocalPartitionReader implements 
PartitionReader {
 
   private void doFetchChunks(int chunkIndex, int toFetch) {
     try {
+      if (shuffleChannel == null) {
+        shuffleChannel = FileChannelUtils.openReadableFileChannel(fullPath);
+        if (mapRangeRead) {
+          shuffleChannel.position(chunkOffsets.get(0));
+        }
+      }
       for (int i = 0; i < toFetch; i++) {
         long offset = chunkOffsets.get(chunkIndex + i);
         long length = chunkOffsets.get(chunkIndex + i + 1) - offset;
@@ -219,7 +225,9 @@ public class LocalPartitionReader implements 
PartitionReader {
       results.clear();
     }
     try {
-      shuffleChannel.close();
+      if (shuffleChannel != null) {
+        shuffleChannel.close();
+      }
     } catch (IOException e) {
       logger.warn("Close local shuffle file failed.", e);
     }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index c23e1b1ac..48dafc219 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -825,6 +825,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
     get(CLIENT_BATCH_HANDLED_RELEASE_PARTITION_INTERVAL)
   def enableReadLocalShuffleFile: Boolean = get(READ_LOCAL_SHUFFLE_FILE)
   def readLocalShuffleThreads: Int = get(READ_LOCAL_SHUFFLE_THREADS)
+  def readStreamCreatorPoolThreads: Int = get(READ_STREAM_CREATOR_POOL_THREADS)
 
   // //////////////////////////////////////////////////////
   //                       Worker                        //
@@ -3818,6 +3819,14 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(4)
 
+  val READ_STREAM_CREATOR_POOL_THREADS: ConfigEntry[Int] =
+    buildConf("celeborn.client.eagerlyCreateInputStream.threads")
+      .categories("client")
+      .version("0.3.1")
+      .doc("Threads count for streamCreatorPool in CelebornShuffleReader.")
+      .intConf
+      .createWithDefault(32)
+
   val CLIENT_SHUFFLE_MAPPARTITION_SPLIT_ENABLED: ConfigEntry[Boolean] =
     buildConf("celeborn.client.shuffle.mapPartition.split.enabled")
       .categories("client")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 8e16b7fbe..f3d8dca02 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -22,6 +22,7 @@ license: |
 | celeborn.client.application.heartbeatInterval | 10s | Interval for client to 
send heartbeat message to master. | 0.3.0 | 
 | celeborn.client.closeIdleConnections | true | Whether client will close idle 
connections. | 0.3.0 | 
 | celeborn.client.commitFiles.ignoreExcludedWorker | false | When true, 
LifecycleManager will skip workers which are in the excluded list. | 0.3.0 | 
+| celeborn.client.eagerlyCreateInputStream.threads | 32 | Threads count for 
streamCreatorPool in CelebornShuffleReader. | 0.3.1 | 
 | celeborn.client.excludePeerWorkerOnFailure.enabled | true | When true, 
Celeborn will exclude partition's peer worker on failure when push data to 
replica failed. | 0.3.0 | 
 | celeborn.client.excludedWorker.expireTimeout | 180s | Timeout time for 
LifecycleManager to clear reserved excluded worker. Default to be 1.5 * 
`celeborn.master.heartbeat.worker.timeout`to cover worker heartbeat timeout 
check period | 0.3.0 | 
 | celeborn.client.fetch.dfsReadChunkSize | 8m | Max chunk size for 
DfsPartitionReader. | 0.3.1 | 

Reply via email to