This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 7dcd25925 [CELEBORN-1671] CelebornShuffleReader will try replica if
create client failed
7dcd25925 is described below
commit 7dcd25925fc25da7bd61678a62bf018597f295a2
Author: mingji <[email protected]>
AuthorDate: Wed Nov 6 11:14:11 2024 +0800
[CELEBORN-1671] CelebornShuffleReader will try replica if create client
failed
### What changes were proposed in this pull request?
1. Tolerate exceptions thrown while creating clients in CelebornShuffleReader
(Spark 3) instead of propagating them.
2. When reading locations, the client will try a location's replica if one exists.
### Why are the changes needed?
Allow clients to fall back to other locations (e.g. replicas) when client creation fails.
### Does this PR introduce _any_ user-facing change?
NO.
### How was this patch tested?
Pass GA.
Closes #2854 from FMX/b1671.
Authored-by: mingji <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../shuffle/celeborn/CelebornShuffleReader.scala | 43 ++++++++++++++--------
.../org/apache/celeborn/client/ShuffleClient.java | 2 +
.../apache/celeborn/client/ShuffleClientImpl.java | 12 ++++++
.../celeborn/client/read/CelebornInputStream.java | 15 ++------
.../apache/celeborn/client/DummyShuffleClient.java | 3 ++
.../org/apache/celeborn/common/util/Utils.scala | 14 ++++++-
6 files changed, 60 insertions(+), 29 deletions(-)
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 1b7b6f1dd..f6405a692 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -119,23 +119,34 @@ class CelebornShuffleReader[K, C](
partCnt += 1
val hostPort = location.hostAndFetchPort
if (!workerRequestMap.containsKey(hostPort)) {
- val client = shuffleClient.getDataClientFactory().createClient(
- location.getHost,
- location.getFetchPort)
- val pbOpenStreamList = PbOpenStreamList.newBuilder()
- pbOpenStreamList.setShuffleKey(shuffleKey)
- workerRequestMap.put(
- hostPort,
- (client, new util.ArrayList[PartitionLocation],
pbOpenStreamList))
+ try {
+ val client = shuffleClient.getDataClientFactory().createClient(
+ location.getHost,
+ location.getFetchPort)
+ val pbOpenStreamList = PbOpenStreamList.newBuilder()
+ pbOpenStreamList.setShuffleKey(shuffleKey)
+ workerRequestMap.put(
+ hostPort,
+ (client, new util.ArrayList[PartitionLocation],
pbOpenStreamList))
+ } catch {
+ case ex: Exception =>
+
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
+ logWarning(
+ s"Failed to create client for $shuffleKey-$partitionId from
host: ${location.hostAndFetchPort}. " +
+ s"Shuffle reader will try its replica if exists.")
+ }
+ }
+ workerRequestMap.get(hostPort) match {
+ case (_, locArr, pbOpenStreamListBuilder) =>
+ locArr.add(location)
+ pbOpenStreamListBuilder.addFileName(location.getFileName)
+ .addStartIndex(startMapIndex)
+ .addEndIndex(endMapIndex)
+ pbOpenStreamListBuilder.addReadLocalShuffle(
+ localFetchEnabled && location.getHost.equals(localHostAddress))
+ case _ =>
+ logDebug(s"Empty client for host ${hostPort}")
}
- val (_, locArr, pbOpenStreamListBuilder) =
workerRequestMap.get(hostPort)
-
- locArr.add(location)
- pbOpenStreamListBuilder.addFileName(location.getFileName)
- .addStartIndex(startMapIndex)
- .addEndIndex(endMapIndex)
- pbOpenStreamListBuilder.addReadLocalShuffle(
- localFetchEnabled && location.getHost.equals(localHostAddress))
}
}
}
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 07ce7b10e..efa9641f6 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -285,4 +285,6 @@ public abstract class ShuffleClient {
public abstract boolean reportBarrierTaskFailure(int appShuffleId, String
appShuffleIdentifier);
public abstract TransportClientFactory getDataClientFactory();
+
+  /**
+   * Reports that creating a client for, or fetching from, the worker at
+   * {@code hostAndFetchPort} failed with {@code e}. Implementations may use the
+   * exception to decide whether to exclude that worker from subsequent fetches.
+   *
+   * @param hostAndFetchPort the worker's host and fetch port (as produced by
+   *     {@code PartitionLocation#hostAndFetchPort})
+   * @param e the exception raised by the failed operation
+   */
+ public abstract void excludeFailedFetchLocation(String hostAndFetchPort,
Exception e);
}
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index b43bd9598..ed1dabc91 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -117,6 +117,8 @@ public class ShuffleClientImpl extends ShuffleClient {
private final Set<String> pushExcludedWorkers =
ConcurrentHashMap.newKeySet();
private final ConcurrentHashMap<String, Long> fetchExcludedWorkers =
JavaUtils.newConcurrentHashMap();
+ private boolean pushReplicateEnabled;
+ private boolean fetchExcludeWorkerOnFailureEnabled;
private final ExecutorService pushDataRetryPool;
@@ -180,6 +182,8 @@ public class ShuffleClientImpl extends ShuffleClient {
pushBufferMaxSize = conf.clientPushBufferMaxSize();
pushExcludeWorkerOnFailureEnabled =
conf.clientPushExcludeWorkerOnFailureEnabled();
shuffleCompressionEnabled =
!conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
+ pushReplicateEnabled = conf.clientPushReplicateEnabled();
+ fetchExcludeWorkerOnFailureEnabled =
conf.clientFetchExcludeWorkerOnFailureEnabled();
if (conf.clientPushReplicateEnabled()) {
pushDataTimeout = conf.pushDataTimeoutMs() * 2;
} else {
@@ -1904,4 +1908,12 @@ public class ShuffleClientImpl extends ShuffleClient {
public TransportClientFactory getDataClientFactory() {
return dataClientFactory;
}
+
+  /**
+   * Records the worker at {@code hostAndFetchPort} in {@code fetchExcludedWorkers}, stamped
+   * with the current wall-clock time, so later fetches can skip it. The exclusion happens only
+   * when all of the following hold:
+   * <ul>
+   *   <li>push replication is enabled (presumably because only then does a replica exist to
+   *       fall back to);</li>
+   *   <li>fetch-side worker exclusion on failure is enabled;</li>
+   *   <li>{@link Utils#isCriticalCauseForFetch} classifies {@code e} as critical.</li>
+   * </ul>
+   * Otherwise this call is a no-op.
+   *
+   * @param hostAndFetchPort the failing worker's host and fetch port
+   * @param e the exception raised while creating a client or fetching data
+   */
+  @Override
+  public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) {
+    if (pushReplicateEnabled
+        && fetchExcludeWorkerOnFailureEnabled
+        && Utils.isCriticalCauseForFetch(e)) {
+      fetchExcludedWorkers.put(hostAndFetchPort, System.currentTimeMillis());
+    }
+  }
}
diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index bd0164cd6..dfbb7c502 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -159,8 +159,6 @@ public abstract class CelebornInputStream extends InputStream {
private final boolean enabledReadLocalShuffle;
private final String localHostAddress;
- private boolean pushReplicateEnabled;
- private boolean fetchExcludeWorkerOnFailureEnabled;
private boolean shuffleCompressionEnabled;
private long fetchExcludedWorkerExpireTimeout;
private ConcurrentHashMap<String, Long> fetchExcludedWorkers;
@@ -205,8 +203,6 @@ public abstract class CelebornInputStream extends InputStream {
this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
this.enabledReadLocalShuffle = conf.enableReadLocalShuffleFile();
this.localHostAddress = Utils.localHostName(conf);
- this.pushReplicateEnabled = conf.clientPushReplicateEnabled();
- this.fetchExcludeWorkerOnFailureEnabled =
conf.clientFetchExcludeWorkerOnFailureEnabled();
this.shuffleCompressionEnabled =
!conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
this.fetchExcludedWorkerExpireTimeout =
conf.clientFetchExcludedWorkerExpireTimeout();
@@ -299,12 +295,6 @@ public abstract class CelebornInputStream extends InputStream {
}
}
- private void excludeFailedLocation(PartitionLocation location, Exception
e) {
- if (pushReplicateEnabled && fetchExcludeWorkerOnFailureEnabled &&
isCriticalCause(e)) {
- fetchExcludedWorkers.put(location.hostAndFetchPort(),
System.currentTimeMillis());
- }
- }
-
private boolean isExcluded(PartitionLocation location) {
Long timestamp = fetchExcludedWorkers.get(location.hostAndFetchPort());
if (timestamp == null) {
@@ -354,7 +344,7 @@ public abstract class CelebornInputStream extends InputStream {
return createReader(location, pbStreamHandler, fetchChunkRetryCnt,
fetchChunkMaxRetry);
} catch (Exception e) {
lastException = e;
- excludeFailedLocation(location, e);
+
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e);
fetchChunkRetryCnt++;
if (location.hasPeer()) {
// fetchChunkRetryCnt % 2 == 0 means both replicas have been tried,
@@ -392,7 +382,8 @@ public abstract class CelebornInputStream extends InputStream {
}
return currentReader.next();
} catch (Exception e) {
- excludeFailedLocation(currentReader.getLocation(), e);
+ shuffleClient.excludeFailedFetchLocation(
+ currentReader.getLocation().hostAndFetchPort(), e);
fetchChunkRetryCnt++;
currentReader.close();
if (fetchChunkRetryCnt == fetchChunkMaxRetry) {
diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 49b6b5c54..a190c3e1b 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -192,6 +192,9 @@ public class DummyShuffleClient extends ShuffleClient {
return null;
}
+  // No-op in this test client: failures are not recorded, so no worker is ever excluded.
+ @Override
+ public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e)
{}
+
public void initReducePartitionMap(int shuffleId, int numPartitions, int
workerNum) {
ConcurrentHashMap<Integer, PartitionLocation> map =
JavaUtils.newConcurrentHashMap();
String host = "host";
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index f6709f696..dc5b6ea34 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -44,7 +44,7 @@ import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf.PORT_MAX_RETRY
-import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.exception.{CelebornException,
CelebornIOException}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{DiskStatus, WorkerInfo}
import org.apache.celeborn.common.network.protocol.TransportMessage
@@ -1343,4 +1343,16 @@ object Utils extends Logging {
throw e
}
}
+
+  /**
+   * Decides whether a fetch-side failure is "critical", i.e. whether the client may exclude
+   * the worker from subsequent fetches. Returns true when any of the following holds:
+   *  - `e` is a `CelebornIOException` whose message starts with "Connecting to" or
+   *    "Failed to" (presumably raised while establishing a connection);
+   *  - `e` is an `IOException` whose cause is a `TimeoutException` (treated as an RPC timeout);
+   *  - `e` is a `CelebornIOException` whose cause is any `IOException` (e.g. a chunk fetch
+   *    that failed on the wire; note this is not limited to timeouts).
+   *
+   * @param e the exception raised while creating a client or fetching data
+   * @return true if the failure should be treated as critical for fetching
+   */
+  def isCriticalCauseForFetch(e: Exception): Boolean = {
+    val rpcTimeout =
+      e.isInstanceOf[IOException] && e.getCause != null &&
+        e.getCause.isInstanceOf[TimeoutException]
+    val connectException =
+      e.isInstanceOf[CelebornIOException] && e.getMessage != null &&
+        (e.getMessage.startsWith("Connecting to") || e.getMessage.startsWith("Failed to"))
+    // Matches any IOException cause, not only timeouts.
+    val fetchChunkIOFailure =
+      e.isInstanceOf[CelebornIOException] && e.getCause != null &&
+        e.getCause.isInstanceOf[IOException]
+    connectException || rpcTimeout || fetchChunkIOFailure
+  }
+
}