This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new a42ec85a6 [CELEBORN-943][PERF] Pre-create CelebornInputStreams in
CelebornShuffleReader
a42ec85a6 is described below
commit a42ec85a6e5099ce2d54bcf01d0eb0e63b26df6d
Author: zky.zhoukeyong <[email protected]>
AuthorDate: Mon Sep 4 21:46:11 2023 +0800
[CELEBORN-943][PERF] Pre-create CelebornInputStreams in
CelebornShuffleReader
### What changes were proposed in this pull request?
This PR fixes a performance degradation, caused by RPC latency, that occurs
when Spark's coalescePartitions takes effect.
### Why are the changes needed?
I encountered a performance degradation when testing tpcds 10T q10:
||Time|
|---|---|
|ESS|14s|
|Celeborn|24s|
After digging into it I found out that q10 triggers partition coalescence:

As I configured `spark.sql.adaptive.coalescePartitions.initialPartitionNum`
to 1000, `CelebornShuffleReader`
will call `shuffleClient.readPartition` sequentially 1000 times, causing
the delay.
This PR optimizes by calling `shuffleClient.readPartition` in parallel.
After this PR q10 time becomes 14s.
### Does this PR introduce _any_ user-facing change?
No, but introduced a new client-side configuration
`celeborn.client.eagerlyCreateInputStream.threads`
which defaults to 32.
### How was this patch tested?
Tested with TPCDS 1T; passes GA.
Closes #1876 from waitinfuture/943.
Lead-authored-by: zky.zhoukeyong <[email protected]>
Co-authored-by: Keyong Zhou <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
---
.../shuffle/celeborn/CelebornShuffleReader.scala | 66 +++++++++++++++++++---
.../shuffle/celeborn/CelebornShuffleReader.scala | 64 +++++++++++++++++++--
.../celeborn/client/read/DfsPartitionReader.java | 6 +-
.../celeborn/client/read/LocalPartitionReader.java | 20 +++++--
.../org/apache/celeborn/common/CelebornConf.scala | 9 +++
docs/configuration/client.md | 1 +
6 files changed, 146 insertions(+), 20 deletions(-)
diff --git
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 8d5c99e0a..337143a50 100644
---
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -17,9 +17,14 @@
package org.apache.spark.shuffle.celeborn
+import java.io.IOException
+import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
+import java.util.concurrent.atomic.AtomicReference
+
import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.ShuffleReader
+import
org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -27,6 +32,8 @@ import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.exception.CelebornIOException
+import org.apache.celeborn.common.util.ThreadUtils
class CelebornShuffleReader[K, C](
handle: CelebornShuffleHandle[K, _, C],
@@ -39,13 +46,15 @@ class CelebornShuffleReader[K, C](
extends ShuffleReader[K, C] with Logging {
private val dep = handle.dependency
- private val essShuffleClient = ShuffleClient.get(
+ private val shuffleClient = ShuffleClient.get(
handle.appUniqueId,
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
conf,
handle.userIdentifier)
+ private val exceptionRef = new AtomicReference[IOException]
+
override def read(): Iterator[Product2[K, C]] = {
val serializerInstance = dep.serializer.newInstance()
@@ -60,15 +69,54 @@ class CelebornShuffleReader[K, C](
readMetrics.incFetchWaitTime(time)
}
+ if (streamCreatorPool == null) {
+ CelebornShuffleReader.synchronized {
+ if (streamCreatorPool == null) {
+ streamCreatorPool = ThreadUtils.newDaemonCachedThreadPool(
+ "celeborn-create-stream-thread",
+ conf.readStreamCreatorPoolThreads,
+ 60);
+ }
+ }
+ }
+
+ val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
+ (startPartition until endPartition).map(partitionId => {
+ streamCreatorPool.submit(new Runnable {
+ override def run(): Unit = {
+ if (exceptionRef.get() == null) {
+ try {
+ val inputStream = shuffleClient.readPartition(
+ handle.shuffleId,
+ partitionId,
+ context.attemptNumber(),
+ startMapIndex,
+ endMapIndex)
+ streams.put(partitionId, inputStream)
+ } catch {
+ case e: IOException =>
+ logInfo("Exception caught when readPartition!")
+ exceptionRef.compareAndSet(null, e)
+ case e: Throwable =>
+ logInfo("Non IOException caught when readPartition!", e)
+ exceptionRef.compareAndSet(null, new CelebornIOException(e))
+ }
+ }
+ }
+ })
+ })
+
val recordIter = (startPartition until
endPartition).iterator.map(partitionId => {
if (handle.numMaps > 0) {
val start = System.currentTimeMillis()
- val inputStream = essShuffleClient.readPartition(
- handle.shuffleId,
- partitionId,
- context.attemptNumber(),
- startMapIndex,
- endMapIndex)
+ var inputStream: CelebornInputStream = streams.get(partitionId)
+ while (inputStream == null) {
+ if (exceptionRef.get() != null) {
+ throw exceptionRef.get()
+ }
+ Thread.sleep(50)
+ inputStream = streams.get(partitionId)
+ }
metricsCallback.incReadTime(System.currentTimeMillis() - start)
inputStream.setCallback(metricsCallback)
// ensure inputStream is closed when task completes
@@ -135,3 +183,7 @@ class CelebornShuffleReader[K, C](
}
}
}
+
+object CelebornShuffleReader {
+ var streamCreatorPool: ThreadPoolExecutor = null
+}
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index eadb655c9..f07ed4989 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -17,16 +17,23 @@
package org.apache.spark.shuffle.celeborn
+import java.io.IOException
+import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
+import java.util.concurrent.atomic.AtomicReference
+
import org.apache.spark.{InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter}
+import
org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.exception.CelebornIOException
+import org.apache.celeborn.common.util.ThreadUtils
class CelebornShuffleReader[K, C](
handle: CelebornShuffleHandle[K, _, C],
@@ -47,6 +54,8 @@ class CelebornShuffleReader[K, C](
conf,
handle.userIdentifier)
+ private val exceptionRef = new AtomicReference[IOException]
+
override def read(): Iterator[Product2[K, C]] = {
val serializerInstance = newSerializerInstance(dep)
@@ -62,15 +71,54 @@ class CelebornShuffleReader[K, C](
metrics.incFetchWaitTime(time)
}
+ if (streamCreatorPool == null) {
+ CelebornShuffleReader.synchronized {
+ if (streamCreatorPool == null) {
+ streamCreatorPool = ThreadUtils.newDaemonCachedThreadPool(
+ "celeborn-create-stream-thread",
+ conf.readStreamCreatorPoolThreads,
+ 60);
+ }
+ }
+ }
+
+ val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
+ (startPartition until endPartition).map(partitionId => {
+ streamCreatorPool.submit(new Runnable {
+ override def run(): Unit = {
+ if (exceptionRef.get() == null) {
+ try {
+ val inputStream = shuffleClient.readPartition(
+ handle.shuffleId,
+ partitionId,
+ context.attemptNumber(),
+ startMapIndex,
+ endMapIndex)
+ streams.put(partitionId, inputStream)
+ } catch {
+ case e: IOException =>
+ logInfo("Exception caught when readPartition!", e)
+ exceptionRef.compareAndSet(null, e)
+ case e: Throwable =>
+ logInfo("Non IOException caught when readPartition!", e)
+ exceptionRef.compareAndSet(null, new CelebornIOException(e))
+ }
+ }
+ }
+ })
+ })
+
val recordIter = (startPartition until
endPartition).iterator.map(partitionId => {
if (handle.numMappers > 0) {
val start = System.currentTimeMillis()
- val inputStream = shuffleClient.readPartition(
- handle.shuffleId,
- partitionId,
- context.attemptNumber(),
- startMapIndex,
- endMapIndex)
+ var inputStream: CelebornInputStream = streams.get(partitionId)
+ while (inputStream == null) {
+ if (exceptionRef.get() != null) {
+ throw exceptionRef.get()
+ }
+ Thread.sleep(50)
+ inputStream = streams.get(partitionId)
+ }
metricsCallback.incReadTime(System.currentTimeMillis() - start)
inputStream.setCallback(metricsCallback)
// ensure inputStream is closed when task completes
@@ -148,3 +196,7 @@ class CelebornShuffleReader[K, C](
}
}
+
+object CelebornShuffleReader {
+ var streamCreatorPool: ThreadPoolExecutor = null
+}
diff --git
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
index 68f6308b3..ec930b8d3 100644
---
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
+++
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
@@ -53,6 +53,7 @@ public class DfsPartitionReader implements PartitionReader {
private final AtomicReference<IOException> exception = new
AtomicReference<>();
private volatile boolean closed = false;
private Thread fetchThread;
+ private boolean fetchThreadStarted;
private FSDataInputStream hdfsInputStream;
private int numChunks = 0;
private int returnedChunks = 0;
@@ -168,7 +169,6 @@ public class DfsPartitionReader implements PartitionReader {
logger.error("thread {} failed with exception {}", t, e);
}
});
- fetchThread.start();
logger.debug("Start dfs read on location {}", location);
ShuffleClient.incrementTotalReadCounter();
}
@@ -218,6 +218,10 @@ public class DfsPartitionReader implements PartitionReader
{
@Override
public ByteBuf next() throws IOException, InterruptedException {
ByteBuf chunk = null;
+ if (!fetchThreadStarted) {
+ fetchThreadStarted = true;
+ fetchThread.start();
+ }
try {
while (chunk == null) {
checkException();
diff --git
a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
index 1168c8d2c..486f515c5 100644
---
a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
+++
b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
@@ -58,7 +58,9 @@ public class LocalPartitionReader implements PartitionReader {
private final int numChunks;
private int returnedChunks = 0;
private int chunkIndex = 0;
- private final FileChannel shuffleChannel;
+ private String fullPath;
+ private boolean mapRangeRead = false;
+ private FileChannel shuffleChannel;
private List<Long> chunkOffsets;
private AtomicBoolean pendingFetchTask = new AtomicBoolean(false);
@@ -111,10 +113,8 @@ public class LocalPartitionReader implements
PartitionReader {
chunkOffsets = new ArrayList<>(streamHandle.getChunkOffsetsList());
numChunks = streamHandle.getNumChunks();
- shuffleChannel =
FileChannelUtils.openReadableFileChannel(streamHandle.getFullPath());
- if (endMapIndex != Integer.MAX_VALUE) {
- shuffleChannel.position(chunkOffsets.get(0));
- }
+ fullPath = streamHandle.getFullPath();
+ mapRangeRead = endMapIndex != Integer.MAX_VALUE;
logger.debug(
"Local partition reader {} offsets:{}",
@@ -126,6 +126,12 @@ public class LocalPartitionReader implements
PartitionReader {
private void doFetchChunks(int chunkIndex, int toFetch) {
try {
+ if (shuffleChannel == null) {
+ shuffleChannel = FileChannelUtils.openReadableFileChannel(fullPath);
+ if (mapRangeRead) {
+ shuffleChannel.position(chunkOffsets.get(0));
+ }
+ }
for (int i = 0; i < toFetch; i++) {
long offset = chunkOffsets.get(chunkIndex + i);
long length = chunkOffsets.get(chunkIndex + i + 1) - offset;
@@ -219,7 +225,9 @@ public class LocalPartitionReader implements
PartitionReader {
results.clear();
}
try {
- shuffleChannel.close();
+ if (shuffleChannel != null) {
+ shuffleChannel.close();
+ }
} catch (IOException e) {
logger.warn("Close local shuffle file failed.", e);
}
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index c23e1b1ac..48dafc219 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -825,6 +825,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
get(CLIENT_BATCH_HANDLED_RELEASE_PARTITION_INTERVAL)
def enableReadLocalShuffleFile: Boolean = get(READ_LOCAL_SHUFFLE_FILE)
def readLocalShuffleThreads: Int = get(READ_LOCAL_SHUFFLE_THREADS)
+ def readStreamCreatorPoolThreads: Int = get(READ_STREAM_CREATOR_POOL_THREADS)
// //////////////////////////////////////////////////////
// Worker //
@@ -3818,6 +3819,14 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(4)
+ val READ_STREAM_CREATOR_POOL_THREADS: ConfigEntry[Int] =
+ buildConf("celeborn.client.eagerlyCreateInputStream.threads")
+ .categories("client")
+ .version("0.3.1")
+ .doc("Threads count for streamCreatorPool in CelebornShuffleReader.")
+ .intConf
+ .createWithDefault(32)
+
val CLIENT_SHUFFLE_MAPPARTITION_SPLIT_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.shuffle.mapPartition.split.enabled")
.categories("client")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 8e16b7fbe..f3d8dca02 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -22,6 +22,7 @@ license: |
| celeborn.client.application.heartbeatInterval | 10s | Interval for client to
send heartbeat message to master. | 0.3.0 |
| celeborn.client.closeIdleConnections | true | Whether client will close idle
connections. | 0.3.0 |
| celeborn.client.commitFiles.ignoreExcludedWorker | false | When true,
LifecycleManager will skip workers which are in the excluded list. | 0.3.0 |
+| celeborn.client.eagerlyCreateInputStream.threads | 32 | Threads count for
streamCreatorPool in CelebornShuffleReader. | 0.3.1 |
| celeborn.client.excludePeerWorkerOnFailure.enabled | true | When true,
Celeborn will exclude partition's peer worker on failure when push data to
replica failed. | 0.3.0 |
| celeborn.client.excludedWorker.expireTimeout | 180s | Timeout time for
LifecycleManager to clear reserved excluded worker. Default to be 1.5 *
`celeborn.master.heartbeat.worker.timeout`to cover worker heartbeat timeout
check period | 0.3.0 |
| celeborn.client.fetch.dfsReadChunkSize | 8m | Max chunk size for
DfsPartitionReader. | 0.3.1 |