cloud-fan commented on code in PR #56559:
URL: https://github.com/apache/spark/pull/56559#discussion_r3462832986


##########
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala:
##########
@@ -2921,6 +2921,140 @@ class TaskSetManagerSuite
         s"\nCaptured logs:\n${logs.mkString("\n")}")
   }
 
+  test("SPARK-57491: late-arriving speculative ShuffleMapTask marks stale 
partitionId") {
+    sc = new SparkContext("local", "test")
+    sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), 
("exec3", "host3"))
+    sc.conf.set(config.SPECULATION_MULTIPLIER, 0.0)
+    sc.conf.set(config.SPECULATION_ENABLED, true)
+
+    val taskSet = FakeTask.createShuffleMapTaskSet(2, 0, 0,
+      Seq(TaskLocation("host1", "exec1")),
+      Seq(TaskLocation("host2", "exec2")))
+    val clock = new ManualClock()
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock 
= clock)
+    val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = 
taskSet.tasks.map { task =>
+      task.metrics.internalAccums
+    }
+
+    // Register shuffle in MapOutputTrackerMaster so 
detectStalePushIfShuffleTask can find it
+    val mapOutputTrackerMaster = sched.mapOutputTracker
+    val shuffleId = taskSet.shuffleId.get
+    mapOutputTrackerMaster.registerShuffle(shuffleId, 2, 2)
+
+    // Offer resources for 2 tasks to start
+    val task0 = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)._1.get
+    val task1 = manager.resourceOffer("exec2", "host2", PROCESS_LOCAL)._1.get
+    assert(task0.index === 0)
+    assert(task1.index === 1)
+
+    // Advance clock so tasks have been running long enough (markFinished 
requires time > 0)
+    clock.advance(1)
+
+    // Complete task 1 (partition 1) successfully with a MapStatus
+    val mapStatus1 = MapStatus(BlockManagerId("exec2", "host2", 2000), 
Array(2L, 2L), mapTaskId = 1)
+    val result1 = createMapStatusTaskResult(mapStatus1, accumUpdatesByTask(1))
+    manager.handleSuccessfulTask(task1.taskId, result1)
+    assert(sched.endedTasks(task1.index) === Success)
+
+    // Advance clock so task 0 has been running long enough for speculation.
+    // checkSpeculatableTasks requires tasks to have been running > 0ms when 
threshold is 0.
+    clock.advance(1)
+    assert(manager.checkSpeculatableTasks(0))
+    assert(sched.speculativeTasks.toSet === Set(0))
+
+    // Offer resource to start the speculative attempt for partition 0 on a 
different host
+    val specTaskOption = manager.resourceOffer("exec3", "host3", ANY)._1
+    assert(specTaskOption.isDefined, "Expected speculative task to be 
launched")
+    val specTask = specTaskOption.get
+    assert(specTask.index === 0)
+    assert(specTask.attemptNumber === 1)
+
+    // Replace backend with mock before completing original task 0, to handle 
killTask call
+    sched.backend = mock(classOf[SchedulerBackend])
+    sched.dagScheduler.stop()
+    sched.dagScheduler = mock(classOf[DAGScheduler])
+
+    // Complete original task 0 (partition 0) - this will kill the speculative 
attempt
+    val mapStatus0 = MapStatus(BlockManagerId("exec1", "host1", 1000), 
Array(1L, 1L), mapTaskId = 0)
+    val result0 = createMapStatusTaskResult(mapStatus0, accumUpdatesByTask(0))
+    manager.handleSuccessfulTask(task0.taskId, result0)
+
+    // Verify no stale pushed map indexes yet (stale is only marked when late 
result arrives)
+    assert(mapOutputTrackerMaster.getStaleMapIndexes(shuffleId).isEmpty)
+
+    // Now the speculative attempt's result arrives late. Since task 0 already 
succeeded,
+    // handleSuccessfulTask will see successful(0)=true and 
killedByOtherAttempt contains
+    // the speculative tid, triggering detectStalePushIfShuffleTask.
+    val specMapStatus = MapStatus(
+      BlockManagerId("exec3", "host3", 3000), Array(3L, 3L), mapTaskId = 999)
+    val specResult = createMapStatusTaskResult(specMapStatus, 
accumUpdatesByTask(0))
+    manager.handleSuccessfulTask(specTask.taskId, specResult)
+
+    // Verify that partition 0 is now tracked as stale
+    val staleMapIndexes = mapOutputTrackerMaster.getStaleMapIndexes(shuffleId)
+    assert(staleMapIndexes.contains(0),
+      s"Expected staleMapIndexes to contain mapIndex 0, got $staleMapIndexes")
+  }
+
+  test("SPARK-33235: late-arriving result for finished task marks stale 
partitionId") {

Review Comment:
   This test drives the `info.finished` path (a duplicate result for the same 
tid) and asserts the partition is marked stale. But production 
`handleSuccessfulTask` returns early on `info.finished` *without* calling 
`detectStalePushIfShuffleTask` — the sole call site is the 
`killedByOtherAttempt` branch (line 828), which is exactly the design you 
agreed to (winner re-delivery must not mark stale). So this test asserts 
behavior the code intentionally does not implement, and it fails CI:
   
   ```
   Set() did not contain 0  Expected staleMapIndexes to contain mapIndex 0 on 
finished-task path, got Set()
   ```
   
   It looks like a leftover from before the `info.finished` call was removed. Either 
delete it, or invert it to assert the finished-task path does **not** mark 
stale — which documents the intentional decision.



##########
core/src/main/scala/org/apache/spark/MapOutputTracker.scala:
##########
@@ -105,6 +105,30 @@ private class ShuffleStatus(
    */
   private[spark] val checksumMismatchIndices: Set[Int] = Set()
 
+  /**
+   * Set of stale pushed partition indexes for this shuffle. Each entry is a 
partitionId (which
+   * equals mapIndex, not MapStatus.mapId). When task retry or speculation 
causes multiple
+   * attempts for the same map output to push, the merger may include data 
from a stale attempt.
+   * We record the stale partition indexes here so the reduce side can check 
chunkBitmaps and
+   * fallback if stale data is present in a merged block.
+   */
+  private[this] val staleMapIndexes = new java.util.HashSet[Int]()
+
+  /**
+   * Mark a partition as having stale (redundant) push attempts. Called from 
TaskSetManager when it
+   * detects that multiple task attempts for the same map output pushed data 
to the merger.
+   * @param partitionId the partition index (== mapIndex) of the stale 
(redundant) attempt;

Review Comment:
   The `@param` tag names a parameter that doesn't exist — the method signature 
is `markStalePushedPartition(mapIndex: Int)`.
   ```suggestion
      * @param mapIndex the map index (== partitionId, not MapStatus.mapId) of the stale 
(redundant) attempt;
   ```



##########
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala:
##########
@@ -166,9 +165,16 @@ private class PushBasedFetchHelper(
           meta: MergedBlockMeta): Unit = {
         logDebug(s"Received the meta of push-merged block for ($shuffleId, 
$shuffleMergeId," +
           s" $reduceId) from ${req.address.host}:${req.address.port}")
+        val mergedBlock = ShuffleMergedBlockId(shuffleId, shuffleMergeId, 
reduceId)
         try {
-          
iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, 
shuffleMergeId,
-            reduceId, sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), 
address))
+          val chunkBitmaps = meta.readChunkBitmaps().toIndexedSeq

Review Comment:
   `meta.readChunkBitmaps()` is called twice here — once to build 
`chunkBitmaps` for `checkStaleMapIdInMergedBlock`, then again when constructing 
`PushMergedRemoteMetaFetchResult`. Each call re-deserializes the 
RoaringBitmaps, even on the common (non-stale) path. Since `readChunkBitmaps()` 
returns `Array[RoaringBitmap]` and the check accepts `Seq[RoaringBitmap]`, read the 
bitmaps once and reuse that value for both calls instead of re-reading. Same pattern 
on the local meta path (line 282).



##########
core/src/main/scala/org/apache/spark/TaskContext.scala:
##########
@@ -195,6 +195,24 @@ abstract class TaskContext extends Serializable {
     })
   }
 
+  /**
+   * Adds a listener to be invoked after the task's status update has been 
sent to the driver.
+   * This is useful for operations that should only begin after the driver has 
been notified
+   * of the task's result. For example, push-based shuffle block push can use 
this to
+   * ensure the driver processes the task result before any push data reaches 
the merger,
+   * avoiding stale data being merged without detection.
+   *
+   * The callback runs on the same executor thread that sends the status 
update.
+   */
+  @Experimental
+  def addPostStatusUpdateListener(listener: PostStatusUpdateListener): 
TaskContext

Review Comment:
   This method is public `@Experimental`, but its parameter trait 
`PostStatusUpdateListener` is `private[spark]`. External code can't name the 
type, so the method is uncallable outside Spark while still leaking an internal 
trait into the public API signature. Since the mechanism is entirely internal 
to push-based shuffle, make the method `private[spark]` to match the trait.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to