cloud-fan commented on code in PR #56559:
URL: https://github.com/apache/spark/pull/56559#discussion_r3462832986
##########
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala:
##########
@@ -2921,6 +2921,140 @@ class TaskSetManagerSuite
s"\nCaptured logs:\n${logs.mkString("\n")}")
}
+ test("SPARK-57491: late-arriving speculative ShuffleMapTask marks stale
partitionId") {
+ sc = new SparkContext("local", "test")
+ sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"),
("exec3", "host3"))
+ sc.conf.set(config.SPECULATION_MULTIPLIER, 0.0)
+ sc.conf.set(config.SPECULATION_ENABLED, true)
+
+ val taskSet = FakeTask.createShuffleMapTaskSet(2, 0, 0,
+ Seq(TaskLocation("host1", "exec1")),
+ Seq(TaskLocation("host2", "exec2")))
+ val clock = new ManualClock()
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock
= clock)
+ val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] =
taskSet.tasks.map { task =>
+ task.metrics.internalAccums
+ }
+
+ // Register shuffle in MapOutputTrackerMaster so
detectStalePushIfShuffleTask can find it
+ val mapOutputTrackerMaster = sched.mapOutputTracker
+ val shuffleId = taskSet.shuffleId.get
+ mapOutputTrackerMaster.registerShuffle(shuffleId, 2, 2)
+
+ // Offer resources for 2 tasks to start
+ val task0 = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL)._1.get
+ val task1 = manager.resourceOffer("exec2", "host2", PROCESS_LOCAL)._1.get
+ assert(task0.index === 0)
+ assert(task1.index === 1)
+
+ // Advance clock so tasks have been running long enough (markFinished
requires time > 0)
+ clock.advance(1)
+
+ // Complete task 1 (partition 1) successfully with a MapStatus
+ val mapStatus1 = MapStatus(BlockManagerId("exec2", "host2", 2000),
Array(2L, 2L), mapTaskId = 1)
+ val result1 = createMapStatusTaskResult(mapStatus1, accumUpdatesByTask(1))
+ manager.handleSuccessfulTask(task1.taskId, result1)
+ assert(sched.endedTasks(task1.index) === Success)
+
+ // Advance clock so task 0 has been running long enough for speculation.
+ // checkSpeculatableTasks requires tasks to have been running > 0ms when
threshold is 0.
+ clock.advance(1)
+ assert(manager.checkSpeculatableTasks(0))
+ assert(sched.speculativeTasks.toSet === Set(0))
+
+ // Offer resource to start the speculative attempt for partition 0 on a
different host
+ val specTaskOption = manager.resourceOffer("exec3", "host3", ANY)._1
+ assert(specTaskOption.isDefined, "Expected speculative task to be
launched")
+ val specTask = specTaskOption.get
+ assert(specTask.index === 0)
+ assert(specTask.attemptNumber === 1)
+
+ // Replace backend with mock before completing original task 0, to handle
killTask call
+ sched.backend = mock(classOf[SchedulerBackend])
+ sched.dagScheduler.stop()
+ sched.dagScheduler = mock(classOf[DAGScheduler])
+
+ // Complete original task 0 (partition 0) - this will kill the speculative
attempt
+ val mapStatus0 = MapStatus(BlockManagerId("exec1", "host1", 1000),
Array(1L, 1L), mapTaskId = 0)
+ val result0 = createMapStatusTaskResult(mapStatus0, accumUpdatesByTask(0))
+ manager.handleSuccessfulTask(task0.taskId, result0)
+
+ // Verify no stale pushed map indexes yet (stale is only marked when late
result arrives)
+ assert(mapOutputTrackerMaster.getStaleMapIndexes(shuffleId).isEmpty)
+
+ // Now the speculative attempt's result arrives late. Since task 0 already
succeeded,
+ // handleSuccessfulTask will see successful(0)=true and
killedByOtherAttempt contains
+ // the speculative tid, triggering detectStalePushIfShuffleTask.
+ val specMapStatus = MapStatus(
+ BlockManagerId("exec3", "host3", 3000), Array(3L, 3L), mapTaskId = 999)
+ val specResult = createMapStatusTaskResult(specMapStatus,
accumUpdatesByTask(0))
+ manager.handleSuccessfulTask(specTask.taskId, specResult)
+
+ // Verify that partition 0 is now tracked as stale
+ val staleMapIndexes = mapOutputTrackerMaster.getStaleMapIndexes(shuffleId)
+ assert(staleMapIndexes.contains(0),
+ s"Expected staleMapIndexes to contain mapIndex 0, got $staleMapIndexes")
+ }
+
+ test("SPARK-33235: late-arriving result for finished task marks stale
partitionId") {
Review Comment:
This test drives the `info.finished` path (a duplicate result for the same
tid) and asserts the partition is marked stale. But production
`handleSuccessfulTask` returns early on `info.finished` *without* calling
`detectStalePushIfShuffleTask` — the sole call site is the
`killedByOtherAttempt` branch (line 828), which is exactly the design you
agreed to (winner re-delivery must not mark stale). So this test asserts
behavior the code intentionally does not implement, and it fails CI:
```
Set() did not contain 0 Expected staleMapIndexes to contain mapIndex 0 on
finished-task path, got Set()
```
It looks like a leftover from before the `info.finished` call was removed. Either
delete it, or invert it to assert that the finished-task path does **not** mark
stale, which would document the intentional decision.
##########
core/src/main/scala/org/apache/spark/MapOutputTracker.scala:
##########
@@ -105,6 +105,30 @@ private class ShuffleStatus(
*/
private[spark] val checksumMismatchIndices: Set[Int] = Set()
+ /**
+ * Set of stale pushed partition indexes for this shuffle. Each entry is a
partitionId (which
+ * equals mapIndex, not MapStatus.mapId). When task retry or speculation
causes multiple
+ * attempts for the same map output to push, the merger may include data
from a stale attempt.
+ * We record the stale partition indexes here so the reduce side can check
chunkBitmaps and
+ * fallback if stale data is present in a merged block.
+ */
+ private[this] val staleMapIndexes = new java.util.HashSet[Int]()
+
+ /**
+ * Mark a partition as having stale (redundant) push attempts. Called from
TaskSetManager when it
+ * detects that multiple task attempts for the same map output pushed data
to the merger.
+ * @param partitionId the partition index (== mapIndex) of the stale
(redundant) attempt;
Review Comment:
The `@param` tag names a parameter that doesn't exist — the method signature
is `markStalePushedPartition(mapIndex: Int)`.
```suggestion
* @param mapIndex the partition index (== mapIndex) of the stale
(redundant) attempt;
```
##########
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala:
##########
@@ -166,9 +165,16 @@ private class PushBasedFetchHelper(
meta: MergedBlockMeta): Unit = {
logDebug(s"Received the meta of push-merged block for ($shuffleId,
$shuffleMergeId," +
s" $reduceId) from ${req.address.host}:${req.address.port}")
+ val mergedBlock = ShuffleMergedBlockId(shuffleId, shuffleMergeId,
reduceId)
try {
-
iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId,
shuffleMergeId,
- reduceId, sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(),
address))
+ val chunkBitmaps = meta.readChunkBitmaps().toIndexedSeq
Review Comment:
`meta.readChunkBitmaps()` is called twice here — once to build
`chunkBitmaps` for `checkStaleMapIdInMergedBlock`, then again when constructing
`PushMergedRemoteMetaFetchResult`. Each call re-deserializes the
RoaringBitmaps, and this extra cost is paid on the common (non-stale) path. Since `bitmaps` is
`Array[RoaringBitmap]` and the check accepts `Seq[RoaringBitmap]`, reuse the
already-read value instead of re-reading. Same pattern on the local meta path
(line 282).
##########
core/src/main/scala/org/apache/spark/TaskContext.scala:
##########
@@ -195,6 +195,24 @@ abstract class TaskContext extends Serializable {
})
}
+ /**
+ * Adds a listener to be invoked after the task's status update has been
sent to the driver.
+ * This is useful for operations that should only begin after the driver has
been notified
+ * of the task's result. For example, push-based shuffle block push can use
this to
+ * ensure the driver processes the task result before any push data reaches
the merger,
+ * avoiding stale data being merged without detection.
+ *
+ * The callback runs on the same executor thread that sends the status
update.
+ */
+ @Experimental
+ def addPostStatusUpdateListener(listener: PostStatusUpdateListener):
TaskContext
Review Comment:
This method is public `@Experimental`, but its parameter trait
`PostStatusUpdateListener` is `private[spark]`. External code can't name the
type, so the method is uncallable outside Spark while still leaking an internal
trait into the public API signature. Since the mechanism is entirely internal
to push-based shuffle, make the method `private[spark]` to match the trait.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]