This is an automated email from the ASF dual-hosted git repository.
wuyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new bc80c84 [SPARK-36575][CORE] Should ignore task finished event if its
task set is gone in TaskSchedulerImpl.handleSuccessfulTask
bc80c84 is described below
commit bc80c844fcb37d8d699d46bb34edadb98ed0d9f7
Author: hujiahua <[email protected]>
AuthorDate: Wed Nov 10 11:20:35 2021 +0800
[SPARK-36575][CORE] Should ignore task finished event if its task set is
gone in TaskSchedulerImpl.handleSuccessfulTask
### What changes were proposed in this pull request?
When an executor finished a task of some stage, the driver will receive a
`StatusUpdate` event to handle it. At the same time, the driver found that the
executor's heartbeat timed out, so the driver also needs to handle the ExecutorLost
event simultaneously. There was a race condition issue here, which could make
`TaskSetManager.successful` and `TaskSetManager.tasksSuccessful` produce wrong results.
The problem is that `TaskResultGetter.enqueueSuccessfulTask` uses an
asynchronous thread to handle the successful task, which means the synchronized lock
of `TaskSchedulerImpl` was released prematurely midway through
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala#L61.
So `TaskSchedulerImpl` may handle executorLost first, and then the asynchronous
thread will go on to handle the successful task. This causes
`TaskSetManager.successful` and `T [...]
### Why are the changes needed?
It will cause `TaskSetManager.successful` and
`TaskSetManager.tasksSuccessful` to produce wrong results.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Add a new test.
Closes #33872 from sleep1661/SPARK-36575.
Lead-authored-by: hujiahua <[email protected]>
Co-authored-by: MattHu <[email protected]>
Signed-off-by: yi.wu <[email protected]>
---
.../apache/spark/scheduler/TaskSchedulerImpl.scala | 8 +-
.../spark/scheduler/TaskSchedulerImplSuite.scala | 86 +++++++++++++++++++++-
2 files changed, 92 insertions(+), 2 deletions(-)
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 55db73a..282f12b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -871,7 +871,13 @@ private[spark] class TaskSchedulerImpl(
taskSetManager: TaskSetManager,
tid: Long,
taskResult: DirectTaskResult[_]): Unit = synchronized {
- taskSetManager.handleSuccessfulTask(tid, taskResult)
+ if (taskIdToTaskSetManager.contains(tid)) {
+ taskSetManager.handleSuccessfulTask(tid, taskResult)
+ } else {
+ logInfo(s"Ignoring update with state finished for task (TID $tid)
because its task set " +
+ "is gone (this is likely the result of receiving duplicate task
finished status updates)" +
+ " or its executor has been marked as failed.")
+ }
}
def handleFailedTask(
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 53dc14c..551d55d 100644
---
a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++
b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -18,9 +18,12 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import java.util.Properties
+import java.util.concurrent.{CountDownLatch, ExecutorService,
LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.duration._
+import scala.language.reflectiveCalls
import org.mockito.ArgumentMatchers.{any, anyInt, anyString, eq => meq}
import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when}
@@ -34,7 +37,7 @@ import org.apache.spark.internal.config
import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile,
TaskResourceRequests}
import org.apache.spark.resource.ResourceUtils._
import org.apache.spark.resource.TestResourceIDs._
-import org.apache.spark.util.{Clock, ManualClock}
+import org.apache.spark.util.{Clock, ManualClock, ThreadUtils}
class FakeSchedulerBackend extends SchedulerBackend {
def start(): Unit = {}
@@ -1995,6 +1998,87 @@ class TaskSchedulerImplSuite extends SparkFunSuite with
LocalSparkContext with B
assert(!normalTSM.runningTasksSet.contains(taskId))
}
+ test("SPARK-36575: Should ignore task finished event if its task set is gone
" +
+ "in TaskSchedulerImpl.handleSuccessfulTask") {
+ val taskScheduler = setupScheduler()
+
+ val latch = new CountDownLatch(2)
+ val resultGetter = new TaskResultGetter(sc.env, taskScheduler) {
+ override protected val getTaskResultExecutor: ExecutorService =
+ new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, new
LinkedBlockingQueue[Runnable],
+ ThreadUtils.namedThreadFactory("task-result-getter")) {
+ override def execute(command: Runnable): Unit = {
+ super.execute(new Runnable {
+ override def run(): Unit = {
+ command.run()
+ latch.countDown()
+ }
+ })
+ }
+ }
+ def taskResultExecutor() : ExecutorService = getTaskResultExecutor
+ }
+ taskScheduler.taskResultGetter = resultGetter
+
+ val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
+ new WorkerOffer("executor1", "host1", 1))
+ val task1 = new ShuffleMapTask(1, 0, null, new Partition {
+ override def index: Int = 0
+ }, Seq(TaskLocation("host0", "executor0")), new Properties, null)
+
+ val task2 = new ShuffleMapTask(1, 0, null, new Partition {
+ override def index: Int = 1
+ }, Seq(TaskLocation("host1", "executor1")), new Properties, null)
+
+ val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0)
+
+ taskScheduler.submitTasks(taskSet)
+ val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(2 === taskDescriptions.length)
+
+ val ser = sc.env.serializer.newInstance()
+ val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(),
Array.empty)
+ val resultBytes = ser.serialize(directResult)
+
+ val busyTask = new Runnable {
+ val lock : Object = new Object
+ override def run(): Unit = {
+ lock.synchronized {
+ lock.wait()
+ }
+ }
+ def markTaskDone: Unit = {
+ lock.synchronized {
+ lock.notify()
+ }
+ }
+ }
+ // make getTaskResultExecutor busy
+ resultGetter.taskResultExecutor().submit(busyTask)
+
+ // task1 finished
+ val tid = taskDescriptions(0).taskId
+ taskScheduler.statusUpdate(
+ tid = tid,
+ state = TaskState.FINISHED,
+ serializedData = resultBytes
+ )
+
+ // mark executor heartbeat timed out
+ taskScheduler.executorLost(taskDescriptions(0).executorId,
ExecutorProcessLost("Executor " +
+ "heartbeat timed out"))
+
+ busyTask.markTaskDone
+
+ // Wait until all events are processed
+ latch.await()
+
+ val taskSetManager =
taskScheduler.taskIdToTaskSetManager.get(taskDescriptions(1).taskId)
+ assert(taskSetManager != null)
+ assert(0 == taskSetManager.tasksSuccessful)
+ assert(!taskSetManager.successful(taskDescriptions(0).index))
+ }
+
/**
* Used by tests to simulate a task failure. This calls the failure handler
explicitly, to ensure
* that all the state is updated when this method returns. Otherwise,
there's no way to know when
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]