Copilot commented on code in PR #3650:
URL: https://github.com/apache/celeborn/pull/3650#discussion_r3061985144


##########
client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -511,36 +547,32 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
                   taskInfo.attemptNumber(),
                   ti.attemptNumber());
               hasRunningAttempt = true;
-            } else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) {
-              // For KILLED state task, Spark does not count the number of failures
-              // For UNKNOWN state task, Spark does count the number of failures
-              // For FAILED state task, Spark decides whether to count the failure based on the
-              // different failure reasons. Since we cannot obtain the failure
-              // reason here, we will count all tasks in FAILED state.
-              LOG.info(
-                  "StageId={} index={} taskId={} attempt={} another attempt {} status={}.",
-                  stageId,
-                  taskInfo.index(),
-                  taskId,
-                  taskInfo.attemptNumber(),
-                  ti.attemptNumber(),
-                  ti.status());
-              failedTaskAttempts += 1;
             }
           }
         }
         // The following situations should trigger a FetchFailed exception:
-        //  1. If failedTaskAttempts >= maxTaskFails
-        //  2. If no other taskAttempts are running
-        if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
+        //  1. If total failures (previous failures + current failure) >= maxTaskFails
+        //  2. If no other taskAttempts are running, trigger a FetchFailed exception
+        //  to keep the same behavior as Spark.
+        // Note: previousFailureCount does NOT include the current failure,
+        //       so (previousFailureCount + 1) represents the total failure count.
+        int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index());
+        // Fail-safe: if failure count cannot be determined, conservatively trigger
+        // FetchFailed to avoid silently swallowing the error.
+        if (previousFailureCount < 0) {
+          return true;
+        }

Review Comment:
   When getTaskFailureCount() fails (returns < 0), 
shouldReportShuffleFetchFailure() immediately returns true. This makes the 
pre-check aggressively report FetchFailed even if there are other running 
attempts and the retry limit has not been reached, which can reintroduce 
premature stage reruns in exactly the scenarios this change is trying to avoid 
(e.g., if reflective access to TaskSetManager.numFailures breaks on some Spark 
builds). Consider falling back to the previous attempt/status-based counting 
(or at least gating on !hasRunningAttempt) instead of unconditional true, and 
log at WARN once to avoid error spam if the field is unavailable.
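   The fallback semantics suggested here can be sketched as a small, self-contained decision helper (hypothetical class and method names; the real pre-check also inspects Spark's TaskSetManager state and would log the WARN):

```java
// Hypothetical standalone sketch of the suggested semantics: when the
// reflective failure count is unavailable (previousFailureCount < 0), gate on
// whether another attempt is still running instead of returning true outright.
public class FetchFailureFallbackSketch {

  public static boolean shouldReportFetchFailure(
      int previousFailureCount, boolean hasRunningAttempt, int maxTaskFails) {
    if (previousFailureCount < 0) {
      // Fallback path: defer while another attempt may still succeed.
      // A rate-limited or log-once WARN about the missing count would go here.
      return !hasRunningAttempt;
    }
    // Normal path: the current failure plus previous failures reaches the
    // retry limit, or no other attempt is running.
    return (previousFailureCount + 1) >= maxTaskFails || !hasRunningAttempt;
  }

  public static void main(String[] args) {
    // Count unavailable but another attempt running: defer (no premature rerun).
    if (shouldReportFetchFailure(-1, true, 4)) throw new AssertionError();
    // Count unavailable and nothing else running: report, matching Spark.
    if (!shouldReportFetchFailure(-1, false, 4)) throw new AssertionError();
    // Retry limit reached: report even with a running attempt.
    if (!shouldReportFetchFailure(3, true, 4)) throw new AssertionError();
    System.out.println("ok");
  }
}
```

   This keeps the fail-safe for the truly unrecoverable case (no running attempt, count unknown) while avoiding an unconditional FetchFailed.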



##########
client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -375,36 +411,32 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
                   taskInfo.attemptNumber(),
                   ti.attemptNumber());
               hasRunningAttempt = true;
-            } else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) {
-              // For KILLED state task, Spark does not count the number of failures
-              // For UNKNOWN state task, Spark does count the number of failures
-              // For FAILED state task, Spark decides whether to count the failure based on the
-              // different failure reasons. Since we cannot obtain the failure
-              // reason here, we will count all tasks in FAILED state.
-              logger.info(
-                  "StageId={} index={} taskId={} attempt={} another attempt {} status={}.",
-                  stageId,
-                  taskInfo.index(),
-                  taskId,
-                  taskInfo.attemptNumber(),
-                  ti.attemptNumber(),
-                  ti.status());
-              failedTaskAttempts += 1;
             }
           }
         }
         // The following situations should trigger a FetchFailed exception:
-        //  1. If failedTaskAttempts >= maxTaskFails
-        //  2. If no other taskAttempts are running
-        if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
+        //  1. If total failures (previous failures + current failure) >= maxTaskFails
+        //  2. If no other taskAttempts are running, trigger a FetchFailed exception
+        //  to keep the same behavior as Spark.
+        // Note: previousFailureCount does NOT include the current failure,
+        //       so (previousFailureCount + 1) represents the total failure count.
+        int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index());
+        // Fail-safe: if failure count cannot be determined, conservatively trigger
+        // FetchFailed to avoid silently swallowing the error.
+        if (previousFailureCount < 0) {
+          return true;

Review Comment:
   When getTaskFailureCount() fails (returns < 0), 
shouldReportShuffleFetchFailure() immediately returns true. This makes the 
pre-check aggressively report FetchFailed even if there are other running 
attempts and the retry limit has not been reached, which can reintroduce 
premature stage reruns if reflective access to TaskSetManager.numFailures fails 
on some Spark builds. Consider falling back to the previous 
attempt/status-based counting (or at least gating on !hasRunningAttempt) rather 
than unconditional true, and avoid logging this as an error on every call if 
the field is unavailable.
   ```suggestion
        // If failure count cannot be determined, fall back to attempt status based
        // behavior instead of aggressively reporting FetchFailed. This avoids
        // premature stage reruns when reflective access to failure counts is
        // unavailable, while still reporting the failure when no other attempt is
        // running.
        if (previousFailureCount < 0) {
          if (!hasRunningAttempt) {
            logger.warn(
                "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine "
                    + "previous failure count, and no other running attempt exists. "
                    + "Reporting shuffle fetch failure.",
                stageId,
                taskInfo.index(),
                taskId,
                taskInfo.attemptNumber());
            return true;
          } else {
            logger.warn(
                "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine "
                    + "previous failure count, but another attempt is still running. "
                    + "Deferring shuffle fetch failure report.",
                stageId,
                taskInfo.index(),
                taskId,
                taskInfo.attemptNumber());
            return false;
          }
   ```



##########
tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala:
##########
@@ -216,6 +216,107 @@ class SparkUtilsSuite extends AnyFunSuite
     }
   }
 
+  test("getTaskFailureCount") {
+    assert(SparkUtils.getTaskFailureCount(null, 0) == -1)
+
+    if (Spark3OrNewer) {
+      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+      val sparkSession = SparkSession.builder()
+        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+        .config("spark.sql.shuffle.partitions", 2)
+        .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+        .config("spark.celeborn.client.spark.stageRerun.enabled", "true")
+        .config(
+          "spark.shuffle.manager",
+          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+        .getOrCreate()
+
+      try {
+        val sc = sparkSession.sparkContext
+        val jobThread = new Thread {
+          override def run(): Unit = {
+            try {
+              sc.parallelize(1 to 100, 2)
+                .repartition(1)
+                .mapPartitions { iter =>
+                  Thread.sleep(3000)
+                  iter
+                }.collect()
+            } catch {
+              case _: InterruptedException =>
+            }
+          }
+        }
+        jobThread.start()
+
+        val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+        eventually(timeout(3.seconds), interval(100.milliseconds)) {
+          val taskId = 0
+          val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId)
+          assert(taskSetManager != null)
+          assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 0)
+          assert(SparkUtils.getTaskFailureCount(taskSetManager, -1) == -1)
+          assert(SparkUtils.getTaskFailureCount(taskSetManager, Int.MaxValue) == -1)
+        }
+
+        sparkSession.sparkContext.cancelAllJobs()
+        jobThread.interrupt()
+      } finally {
+        sparkSession.stop()
+      }
+    }
+  }
+
+  test("getTaskFailureCount after real task failures") {
+    if (Spark3OrNewer) {
+      // local[1,4]: 1 core (sequential execution), max 4 task failures before stage abort
+      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[1,4]")
+      val sparkSession = SparkSession.builder()
+        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+        .config("spark.sql.shuffle.partitions", 2)
+        .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+        .config("spark.celeborn.client.spark.stageRerun.enabled", "true")
+        .config(
+          "spark.shuffle.manager",
+          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+        .getOrCreate()
+
+      try {
+        val sc = sparkSession.sparkContext
+
+        val jobThread = new Thread {
+          override def run(): Unit = {
+            try {
+              sc.parallelize(1 to 10, 1).mapPartitions { iter =>
+                if (TaskContext.get().attemptNumber() < 2) {
+                  throw new RuntimeException("Simulated task failure")
+                }
+                Thread.sleep(10000)
+                iter
+              }.collect()
+            } catch {
+              case _: Exception =>
+            }
+          }
+        }
+        jobThread.start()
+
+        val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+        eventually(timeout(10.seconds), interval(100.milliseconds)) {
+          // taskId 0,1 failed and removed; taskId 2 is the surviving 3rd attempt
+          val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, 2)
+          assert(taskSetManager != null)
+          assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 2)

Review Comment:
   This test assumes the third attempt’s taskId will be exactly 2 (after two 
failures). Spark task IDs are globally assigned within a SparkContext and 
aren’t guaranteed to align with attempt count if any other tasks/stages run 
(including internal ones), which can make the test brittle across Spark 
versions/configs. Consider deriving the taskId dynamically (e.g., capturing 
TaskContext.taskAttemptId() via an accumulator/Promise, or scanning 
taskScheduler’s taskIdToTaskSetManager for the active TaskSetManager) instead 
of hardcoding 2.
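   One way to derive the taskId dynamically, sketched under the assumption that the test keeps running in local mode (tasks execute in the driver JVM, so a top-level singleton is shared between task and driver; `SurvivingTask` is a hypothetical helper object, not part of the PR):

```scala
// Hypothetical sketch: in local[1,4] every task runs in the driver JVM, so the
// surviving attempt can publish its globally assigned task ID through a
// top-level singleton instead of the test assuming it is 2. A top-level object
// is accessed statically, so it is not serialized away with the closure.
object SurvivingTask {
  val id = new java.util.concurrent.atomic.AtomicLong(-1L)
}

// Inside the job thread, in place of the hardcoded expectation:
sc.parallelize(1 to 10, 1).mapPartitions { iter =>
  val ctx = TaskContext.get()
  if (ctx.attemptNumber() < 2) {
    throw new RuntimeException("Simulated task failure")
  }
  SurvivingTask.id.set(ctx.taskAttemptId()) // ID of the attempt that survives
  Thread.sleep(10000)
  iter
}.collect()

// Driver side, inside eventually { ... }:
assert(SurvivingTask.id.get() >= 0)
val taskSetManager =
  SparkUtils.getTaskSetManager(taskScheduler, SurvivingTask.id.get())
assert(taskSetManager != null)
assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 2)
```

   This keeps the assertion on the failure count while removing the dependence on Spark's global task-ID assignment order.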



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
