Copilot commented on code in PR #3650:
URL: https://github.com/apache/celeborn/pull/3650#discussion_r3061985144
##########
client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -511,36 +547,32 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
              taskInfo.attemptNumber(),
              ti.attemptNumber());
          hasRunningAttempt = true;
-        } else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) {
-          // For KILLED state task, Spark does not count the number of failures
-          // For UNKNOWN state task, Spark does count the number of failures
-          // For FAILED state task, Spark decides whether to count the failure based on the
-          // different failure reasons. Since we cannot obtain the failure
-          // reason here, we will count all tasks in FAILED state.
-          LOG.info(
-              "StageId={} index={} taskId={} attempt={} another attempt {} status={}.",
-              stageId,
-              taskInfo.index(),
-              taskId,
-              taskInfo.attemptNumber(),
-              ti.attemptNumber(),
-              ti.status());
-          failedTaskAttempts += 1;
        }
      }
    }
    // The following situations should trigger a FetchFailed exception:
-    // 1. If failedTaskAttempts >= maxTaskFails
-    // 2. If no other taskAttempts are running
-    if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
+    // 1. If total failures (previous failures + current failure) >= maxTaskFails
+    // 2. If no other taskAttempts are running, trigger a FetchFailed exception
+    //    to keep the same behavior as Spark.
+    // Note: previousFailureCount does NOT include the current failure,
+    //       so (previousFailureCount + 1) represents the total failure count.
+    int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index());
+    // Fail-safe: if failure count cannot be determined, conservatively trigger
+    // FetchFailed to avoid silently swallowing the error.
+    if (previousFailureCount < 0) {
+      return true;
+    }
Review Comment:
When getTaskFailureCount() fails (returns < 0),
shouldReportShuffleFetchFailure() immediately returns true. This makes the
pre-check aggressively report FetchFailed even if there are other running
attempts and the retry limit has not been reached, which can reintroduce
premature stage reruns in exactly the scenarios this change is trying to avoid
(e.g., if reflective access to TaskSetManager.numFailures breaks on some Spark
builds). Consider falling back to the previous attempt/status-based counting
(or at least gating on !hasRunningAttempt) instead of unconditional true, and
log at WARN once to avoid error spam if the field is unavailable.
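A minimal standalone Java sketch of the fallback this comment suggests. All names here (`shouldReport`, `reflectionWarned`, the WARN text) are illustrative, not the PR's actual code; it only shows the shape of gating the `< 0` case on `hasRunningAttempt` and warning once rather than on every call:

```java
import java.util.concurrent.atomic.AtomicBoolean;

public class FetchFailureFallbackSketch {
  // Set at most once across all calls, so a broken reflective
  // numFailures lookup produces a single WARN instead of log spam.
  private static final AtomicBoolean reflectionWarned = new AtomicBoolean(false);

  /**
   * previousFailureCount < 0 signals that the reflective lookup of
   * TaskSetManager.numFailures failed; in that case fall back to the
   * attempt-status signal instead of returning true unconditionally.
   */
  static boolean shouldReport(
      int previousFailureCount, int maxTaskFails, boolean hasRunningAttempt) {
    if (previousFailureCount < 0) {
      if (reflectionWarned.compareAndSet(false, true)) {
        System.err.println(
            "WARN: unable to determine previous failure count via reflection; "
                + "falling back to attempt-status based reporting");
      }
      // Only report when no other attempt can still succeed.
      return !hasRunningAttempt;
    }
    // Normal path: current failure plus previous failures reaches the
    // limit, or no other attempt is running.
    return (previousFailureCount + 1) >= maxTaskFails || !hasRunningAttempt;
  }

  public static void main(String[] args) {
    // Reflection broken but another attempt is running: defer.
    assert !shouldReport(-1, 4, true);
    // Reflection broken and nothing else running: report.
    assert shouldReport(-1, 4, false);
    // Normal path: 3 previous failures + this one hits maxTaskFails = 4.
    assert shouldReport(3, 4, true);
    // Normal path: below the limit with a running attempt: defer.
    assert !shouldReport(1, 4, true);
  }
}
```

This keeps the fail-safe (a lost failure count with no surviving attempt still reports FetchFailed) while avoiding the premature stage rerun when other attempts are alive.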
##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -375,36 +411,32 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
              taskInfo.attemptNumber(),
              ti.attemptNumber());
          hasRunningAttempt = true;
-        } else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) {
-          // For KILLED state task, Spark does not count the number of failures
-          // For UNKNOWN state task, Spark does count the number of failures
-          // For FAILED state task, Spark decides whether to count the failure based on the
-          // different failure reasons. Since we cannot obtain the failure
-          // reason here, we will count all tasks in FAILED state.
-          logger.info(
-              "StageId={} index={} taskId={} attempt={} another attempt {} status={}.",
-              stageId,
-              taskInfo.index(),
-              taskId,
-              taskInfo.attemptNumber(),
-              ti.attemptNumber(),
-              ti.status());
-          failedTaskAttempts += 1;
        }
      }
    }
    // The following situations should trigger a FetchFailed exception:
-    // 1. If failedTaskAttempts >= maxTaskFails
-    // 2. If no other taskAttempts are running
-    if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
+    // 1. If total failures (previous failures + current failure) >= maxTaskFails
+    // 2. If no other taskAttempts are running, trigger a FetchFailed exception
+    //    to keep the same behavior as Spark.
+    // Note: previousFailureCount does NOT include the current failure,
+    //       so (previousFailureCount + 1) represents the total failure count.
+    int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index());
+    // Fail-safe: if failure count cannot be determined, conservatively trigger
+    // FetchFailed to avoid silently swallowing the error.
+    if (previousFailureCount < 0) {
+      return true;
Review Comment:
When getTaskFailureCount() fails (returns < 0),
shouldReportShuffleFetchFailure() immediately returns true. This makes the
pre-check aggressively report FetchFailed even if there are other running
attempts and the retry limit has not been reached, which can reintroduce
premature stage reruns if reflective access to TaskSetManager.numFailures fails
on some Spark builds. Consider falling back to the previous
attempt/status-based counting (or at least gating on !hasRunningAttempt) rather
than unconditional true, and avoid logging this as an error on every call if
the field is unavailable.
```suggestion
    // If failure count cannot be determined, fall back to attempt status based
    // behavior instead of aggressively reporting FetchFailed. This avoids
    // premature stage reruns when reflective access to failure counts is
    // unavailable, while still reporting the failure when no other attempt is
    // running.
    if (previousFailureCount < 0) {
      if (!hasRunningAttempt) {
        logger.warn(
            "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine "
                + "previous failure count, and no other running attempt exists. "
                + "Reporting shuffle fetch failure.",
            stageId,
            taskInfo.index(),
            taskId,
            taskInfo.attemptNumber());
        return true;
      } else {
        logger.warn(
            "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine "
                + "previous failure count, but another attempt is still running. "
                + "Deferring shuffle fetch failure report.",
            stageId,
            taskInfo.index(),
            taskId,
            taskInfo.attemptNumber());
        return false;
      }
```
##########
tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala:
##########
@@ -216,6 +216,107 @@ class SparkUtilsSuite extends AnyFunSuite
}
}
+ test("getTaskFailureCount") {
+ assert(SparkUtils.getTaskFailureCount(null, 0) == -1)
+
+ if (Spark3OrNewer) {
+    val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+ val sparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .config("spark.sql.shuffle.partitions", 2)
+      .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+ .config("spark.celeborn.client.spark.stageRerun.enabled", "true")
+ .config(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+ .getOrCreate()
+
+ try {
+ val sc = sparkSession.sparkContext
+ val jobThread = new Thread {
+ override def run(): Unit = {
+ try {
+ sc.parallelize(1 to 100, 2)
+ .repartition(1)
+ .mapPartitions { iter =>
+ Thread.sleep(3000)
+ iter
+ }.collect()
+ } catch {
+ case _: InterruptedException =>
+ }
+ }
+ }
+ jobThread.start()
+
+ val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+ eventually(timeout(3.seconds), interval(100.milliseconds)) {
+ val taskId = 0
+        val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId)
+ assert(taskSetManager != null)
+ assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 0)
+ assert(SparkUtils.getTaskFailureCount(taskSetManager, -1) == -1)
+        assert(SparkUtils.getTaskFailureCount(taskSetManager, Int.MaxValue) == -1)
+ }
+
+ sparkSession.sparkContext.cancelAllJobs()
+ jobThread.interrupt()
+ } finally {
+ sparkSession.stop()
+ }
+ }
+ }
+
+ test("getTaskFailureCount after real task failures") {
+ if (Spark3OrNewer) {
+    // local[1,4]: 1 core (sequential execution), max 4 task failures before stage abort
+    val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[1,4]")
+ val sparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .config("spark.sql.shuffle.partitions", 2)
+      .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+ .config("spark.celeborn.client.spark.stageRerun.enabled", "true")
+ .config(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+ .getOrCreate()
+
+ try {
+ val sc = sparkSession.sparkContext
+
+ val jobThread = new Thread {
+ override def run(): Unit = {
+ try {
+ sc.parallelize(1 to 10, 1).mapPartitions { iter =>
+ if (TaskContext.get().attemptNumber() < 2) {
+ throw new RuntimeException("Simulated task failure")
+ }
+ Thread.sleep(10000)
+ iter
+ }.collect()
+ } catch {
+ case _: Exception =>
+ }
+ }
+ }
+ jobThread.start()
+
+ val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+ eventually(timeout(10.seconds), interval(100.milliseconds)) {
+        // taskId 0,1 failed and removed; taskId 2 is the surviving 3rd attempt
+ val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, 2)
+ assert(taskSetManager != null)
+ assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 2)
Review Comment:
This test assumes the third attempt’s taskId will be exactly 2 (after two
failures). Spark task IDs are globally assigned within a SparkContext and
aren’t guaranteed to align with attempt count if any other tasks/stages run
(including internal ones), which can make the test brittle across Spark
versions/configs. Consider deriving the taskId dynamically (e.g., capturing
TaskContext.taskAttemptId() via an accumulator/Promise, or scanning
taskScheduler’s taskIdToTaskSetManager for the active TaskSetManager) instead
of hardcoding 2.
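To illustrate the "scan taskIdToTaskSetManager" alternative, here is a small standalone Java sketch. It deliberately stands in for `TaskSchedulerImpl`'s internal map with a plain `Map<Long, Object>`; `latestTaskIdFor` and the example task ids are hypothetical, and the real fix would operate on Spark's actual structures (or capture `TaskContext.taskAttemptId()` via an accumulator) rather than this simplification:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.OptionalLong;

public class ActiveTaskIdLookup {
  /**
   * Stand-in for scanning TaskSchedulerImpl's taskIdToTaskSetManager map:
   * return the highest taskId currently mapped to the given TaskSetManager
   * (i.e., the latest live attempt), or empty if none. This avoids
   * hardcoding a globally assigned task id such as 2.
   */
  static OptionalLong latestTaskIdFor(Map<Long, Object> taskIdToManager, Object manager) {
    return taskIdToManager.entrySet().stream()
        .filter(e -> e.getValue() == manager) // same manager instance
        .mapToLong(Map.Entry::getKey)
        .max();
  }

  public static void main(String[] args) {
    Object tsm = new Object();      // the tracked stage's TaskSetManager
    Object otherTsm = new Object(); // some other stage's manager
    Map<Long, Object> taskIdToManager = new HashMap<>();
    taskIdToManager.put(5L, otherTsm); // e.g. an earlier internal task
    taskIdToManager.put(9L, tsm);      // latest attempt of the tracked task
    assert latestTaskIdFor(taskIdToManager, tsm).getAsLong() == 9L;
    assert latestTaskIdFor(taskIdToManager, otherTsm).getAsLong() == 5L;
    assert !latestTaskIdFor(taskIdToManager, new Object()).isPresent();
  }
}
```

Deriving the id this way keeps the test valid even when other tasks (including Spark-internal ones) shift the global task-id counter.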
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]