This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 62db42e28 [CELEBORN-2032] Create reader should change to peer by
taskAttemptId
62db42e28 is described below
commit 62db42e2885fb53c6c559a989484c2da716dd3fd
Author: Xianming Lei <[email protected]>
AuthorDate: Wed Oct 22 10:37:51 2025 +0800
[CELEBORN-2032] Create reader should change to peer by taskAttemptId
### What changes were proposed in this pull request?
In the dual-replica scenario, when creating a reader, we should select the
replica based on taskAttemptId. Usually, taskAttempt0 selects primary
partitionLocation, taskAttempt1 selects replica partitionLocation, and so on.
This will provide better fault tolerance.
### Why are the changes needed?
Since https://github.com/apache/celeborn/pull/3079, we removed the logic
that used replica data when the task attempt number was odd, but if the data
of primary partitionLocation is corrupted and CelebornInputStream#fillBuffer
throws exception, such as decompression failure or some other exceptions, the
replica partitionLocation will not be used when the task is retried. In fact,
if taskAttempt1 uses the replica partitionLocation, taskAttempt1 can run
successfully.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing UTs.
Closes #3490 from leixm/CELEBORN-2032.
Authored-by: Xianming Lei <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
.../shuffle/celeborn/CelebornShuffleReader.scala | 26 +++++++++++++++-------
1 file changed, 18 insertions(+), 8 deletions(-)
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index ff77eecf9..55e036155 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -96,6 +96,8 @@ class CelebornShuffleReader[K, C](
private val exceptionRef = new AtomicReference[IOException]
private val stageRerunEnabled = handle.stageRerunEnabled
private val encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(context)
+ private val pushReplicateEnabled = conf.clientPushReplicateEnabled
+ private val preferReplicaRead = context.attemptNumber % 2 == 1
override def read(): Iterator[Product2[K, C]] = {
@@ -241,7 +243,20 @@ class CelebornShuffleReader[K, C](
partitionIdList.foreach { partitionId =>
if (fileGroups.partitionGroups.containsKey(partitionId)) {
- var locations = fileGroups.partitionGroups.get(partitionId)
+ // CELEBORN-2032. For the first time of open stream and
+ // attemptNumber % 2 = 1, we should read the replica data first.
+ val originLocations = fileGroups.partitionGroups.get(partitionId)
+ val hasReplicate = pushReplicateEnabled &&
+ originLocations.asScala.exists(p => p != null && p.hasPeer)
+ var locations =
+ if (preferReplicaRead && hasReplicate) {
+ originLocations.asScala.map { p =>
+ if (p != null && p.hasPeer) p.getPeer else p
+ }.asJava
+ } else {
+ originLocations
+ }
+
if (splitSkewPartitionWithoutMapRange) {
val partitionLocation2ChunkRange =
CelebornPartitionUtil.splitSkewedPartitionLocations(
new JArrayList(locations),
@@ -255,8 +270,8 @@ class CelebornShuffleReader[K, C](
partitionLocation2ChunkRange.containsKey(location.getUniqueId)
}
locations = filterLocations.asJava
- partitionId2PartitionLocations.put(partitionId, locations)
}
+ partitionId2PartitionLocations.put(partitionId, locations)
makeOpenStreamList(locations)
}
}
@@ -299,12 +314,7 @@ class CelebornShuffleReader[K, C](
val streams = JavaUtils.newConcurrentHashMap[Integer,
CelebornInputStream]()
def createInputStream(partitionId: Int): Unit = {
- val locations =
- if (splitSkewPartitionWithoutMapRange) {
- partitionId2PartitionLocations.get(partitionId)
- } else {
- fileGroups.partitionGroups.get(partitionId)
- }
+ val locations = partitionId2PartitionLocations.get(partitionId)
val locationList =
if (null == locations) {