This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 9dd64d40c91 [SPARK-38916][CORE] Tasks not killed caused by race
conditions between killTask() and launchTask()
9dd64d40c91 is described below
commit 9dd64d40c91253c275fef2313c6a326ef72112cb
Author: Maryann Xue <[email protected]>
AuthorDate: Thu Apr 21 16:30:54 2022 +0800
[SPARK-38916][CORE] Tasks not killed caused by race conditions between
killTask() and launchTask()
### What changes were proposed in this pull request?
This PR fixes the race conditions between the killTask() call and the
launchTask() call that sometimes cause tasks not to be killed properly. If
killTask() probes the map of running tasks (runningTasks) before launchTask() has had a
chance to put the corresponding task into that map, the kill flag will be lost
and the subsequent launchTask() call will just proceed and run that task
without knowing this task should be killed instead. The fix adds a kill mark
during the killTask() call so that [...]
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added UTs.
Closes #36238 from maryannxue/spark-38916.
Authored-by: Maryann Xue <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit bb5092b9af60afdceeccb239d14be660f77ae0ea)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../scala/org/apache/spark/executor/Executor.scala | 51 +++++-
.../CoarseGrainedExecutorBackendSuite.scala | 185 ++++++++++++++++++++-
.../org/apache/spark/executor/ExecutorSuite.scala | 10 +-
3 files changed, 230 insertions(+), 16 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala
b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3f1023e3491..4c84224dd05 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -83,7 +83,7 @@ private[spark] class Executor(
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
- private val conf = env.conf
+ private[executor] val conf = env.conf
// No ip or host:port - just hostname
Utils.checkHost(executorHostname)
@@ -104,7 +104,7 @@ private[spark] class Executor(
// Use UninterruptibleThread to run tasks so that we can allow running codes
without being
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894,
HADOOP-10622,
// will hang forever if some methods are interrupted.
- private val threadPool = {
+ private[executor] val threadPool = {
val threadFactory = new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("Executor task launch worker-%d")
@@ -174,7 +174,33 @@ private[spark] class Executor(
private val maxResultSize = conf.get(MAX_RESULT_SIZE)
// Maintains the list of running tasks.
- private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+ private[executor] val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+
+ // Kill mark TTL in milliseconds - 10 seconds.
+ private val KILL_MARK_TTL_MS = 10000L
+
+ // Kill marks with interruptThread flag, kill reason and timestamp.
+ // This is to avoid dropping the kill event when killTask() is called before
launchTask().
+ private[executor] val killMarks = new ConcurrentHashMap[Long, (Boolean,
String, Long)]
+
+ private val killMarkCleanupTask = new Runnable {
+ override def run(): Unit = {
+ val oldest = System.currentTimeMillis() - KILL_MARK_TTL_MS
+ val iter = killMarks.entrySet().iterator()
+ while (iter.hasNext) {
+ if (iter.next().getValue._3 < oldest) {
+ iter.remove()
+ }
+ }
+ }
+ }
+
+ // Kill mark cleanup thread executor.
+ private val killMarkCleanupService =
+
ThreadUtils.newDaemonSingleThreadScheduledExecutor("executor-kill-mark-cleanup")
+
+ killMarkCleanupService.scheduleAtFixedRate(
+ killMarkCleanupTask, KILL_MARK_TTL_MS, KILL_MARK_TTL_MS,
TimeUnit.MILLISECONDS)
/**
* When an executor is unable to send heartbeats to the driver more than
`HEARTBEAT_MAX_FAILURES`
@@ -264,9 +290,18 @@ private[spark] class Executor(
decommissioned = true
}
+ private[executor] def createTaskRunner(context: ExecutorBackend,
+ taskDescription: TaskDescription) = new TaskRunner(context,
taskDescription, plugins)
+
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription):
Unit = {
- val tr = new TaskRunner(context, taskDescription, plugins)
- runningTasks.put(taskDescription.taskId, tr)
+ val taskId = taskDescription.taskId
+ val tr = createTaskRunner(context, taskDescription)
+ runningTasks.put(taskId, tr)
+ val killMark = killMarks.get(taskId)
+ if (killMark != null) {
+ tr.kill(killMark._1, killMark._2)
+ killMarks.remove(taskId)
+ }
threadPool.execute(tr)
if (decommissioned) {
log.error(s"Launching a task while in decommissioned state.")
@@ -274,6 +309,7 @@ private[spark] class Executor(
}
def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit =
{
+ killMarks.put(taskId, (interruptThread, reason,
System.currentTimeMillis()))
val taskRunner = runningTasks.get(taskId)
if (taskRunner != null) {
if (taskReaperEnabled) {
@@ -296,6 +332,8 @@ private[spark] class Executor(
} else {
taskRunner.kill(interruptThread = interruptThread, reason = reason)
}
+ // Safe to remove kill mark as we got a chance with the TaskRunner.
+ killMarks.remove(taskId)
}
}
@@ -334,6 +372,9 @@ private[spark] class Executor(
if (threadPool != null) {
threadPool.shutdown()
}
+ if (killMarkCleanupService != null) {
+ killMarkCleanupService.shutdown()
+ }
if (replClassLoader != null && plugins != null) {
// Notify plugins that executor is shutting down so they can terminate
cleanly
Utils.withContextClassLoader(replClassLoader) {
diff --git
a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index 4909a586d31..5210990f3b9 100644
---
a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++
b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -21,14 +21,17 @@ import java.io.File
import java.net.URL
import java.nio.ByteBuffer
import java.util.Properties
+import java.util.concurrent.ConcurrentHashMap
+import scala.collection.concurrent.TrieMap
import scala.collection.mutable
import scala.concurrent.duration._
import org.json4s.{DefaultFormats, Extraction}
import org.json4s.JsonAST.{JArray, JObject}
import org.json4s.JsonDSL._
-import org.mockito.Mockito.when
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
import org.scalatest.concurrent.Eventually.{eventually, timeout}
import org.scalatestplus.mockito.MockitoSugar
@@ -39,9 +42,9 @@ import org.apache.spark.resource.ResourceUtils._
import org.apache.spark.resource.TestResourceIDs._
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.TaskDescription
-import
org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask
+import
org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{KillTask,
LaunchTask}
import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.util.{SerializableBuffer, Utils}
+import org.apache.spark.util.{SerializableBuffer, ThreadUtils, Utils}
class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
with LocalSparkContext with MockitoSugar {
@@ -357,6 +360,182 @@ class CoarseGrainedExecutorBackendSuite extends
SparkFunSuite
assert(arg.bindAddress == "bindaddress1")
}
+ /**
+ * This testcase is to verify that [[Executor.killTask()]] will always
cancel a task that is
+ * being executed in [[Executor.TaskRunner]].
+ */
+ test(s"Tasks launched should always be cancelled.") {
+ val conf = new SparkConf
+ val securityMgr = new SecurityManager(conf)
+ val serializer = new JavaSerializer(conf)
+ val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
+ var backend: CoarseGrainedExecutorBackend = null
+
+ try {
+ val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
+ val env = createMockEnv(conf, serializer, Some(rpcEnv))
+ backend = new CoarseGrainedExecutorBackend(env.rpcEnv,
rpcEnv.address.hostPort, "1",
+ "host1", "host1", 4, env, None,
+ resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
+
+ backend.rpcEnv.setupEndpoint("Executor 1", backend)
+ backend.executor = mock[Executor](CALLS_REAL_METHODS)
+ val executor = backend.executor
+ // Mock the executor.
+ when(executor.threadPool).thenReturn(threadPool)
+ val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
+ when(executor.runningTasks).thenAnswer(_ => runningTasks)
+ when(executor.conf).thenReturn(conf)
+
+ // We don't really verify the data, just pass it around.
+ val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
+
+ val numTasks = 1000
+ val tasksKilled = new TrieMap[Long, Boolean]()
+ val tasksExecuted = new TrieMap[Long, Boolean]()
+
+ // Fake tasks with different taskIds.
+ val taskDescriptions = (1 to numTasks).map {
+ taskId => new TaskDescription(taskId, 2, "1", "TASK ${taskId}", 19,
+ 1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new
Properties, 1,
+ Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
+ }
+ assert(taskDescriptions.length == numTasks)
+
+ def getFakeTaskRunner(taskDescription: TaskDescription):
Executor#TaskRunner = {
+ new executor.TaskRunner(backend, taskDescription, None) {
+ override def run(): Unit = {
+ tasksExecuted.put(taskDescription.taskId, true)
+ logInfo(s"task ${taskDescription.taskId} runs.")
+ }
+
+ override def kill(interruptThread: Boolean, reason: String): Unit = {
+ logInfo(s"task ${taskDescription.taskId} killed.")
+ tasksKilled.put(taskDescription.taskId, true)
+ }
+ }
+ }
+
+ // Feed the fake task-runners to be executed by the executor.
+ val firstLaunchTask = getFakeTaskRunner(taskDescriptions(1))
+ val otherTasks = taskDescriptions.slice(1,
numTasks).map(getFakeTaskRunner(_)).toArray
+ assert (otherTasks.length == numTasks - 1)
+ // Workaround for compilation issue around Mockito.doReturn
+ doReturn(firstLaunchTask, otherTasks: _*).when(executor).
+ createTaskRunner(any(), any())
+
+ // Launch tasks and quickly kill them so that TaskRunner.kill will
be triggered.
+ taskDescriptions.foreach { taskDescription =>
+ val buffer = new
SerializableBuffer(TaskDescription.encode(taskDescription))
+ backend.self.send(LaunchTask(buffer))
+ Thread.sleep(1)
+ backend.self.send(KillTask(taskDescription.taskId, "exec1", false,
"test"))
+ }
+
+ eventually(timeout(10.seconds)) {
+ verify(runningTasks, times(numTasks)).put(any(), any())
+ }
+
+ assert(tasksExecuted.size == tasksKilled.size,
+ s"Tasks killed ${tasksKilled.size} != tasks executed
${tasksExecuted.size}")
+ assert(tasksExecuted.keySet == tasksKilled.keySet)
+ logInfo(s"Task executed ${tasksExecuted.size}, task killed
${tasksKilled.size}")
+ } finally {
+ if (backend != null) {
+ backend.rpcEnv.shutdown()
+ }
+ threadPool.shutdownNow()
+ }
+ }
+
+ /**
+ * This testcase is to verify that [[Executor.killTask()]] will always
cancel a task even if
+ * it has not been launched yet.
+ */
+ test(s"Tasks not launched should always be cancelled.") {
+ val conf = new SparkConf
+ val securityMgr = new SecurityManager(conf)
+ val serializer = new JavaSerializer(conf)
+ val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
+ var backend: CoarseGrainedExecutorBackend = null
+
+ try {
+ val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
+ val env = createMockEnv(conf, serializer, Some(rpcEnv))
+ backend = new CoarseGrainedExecutorBackend(env.rpcEnv,
rpcEnv.address.hostPort, "1",
+ "host1", "host1", 4, env, None,
+ resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
+
+ backend.rpcEnv.setupEndpoint("Executor 1", backend)
+ backend.executor = mock[Executor](CALLS_REAL_METHODS)
+ val executor = backend.executor
+ // Mock the executor.
+ when(executor.threadPool).thenReturn(threadPool)
+ val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
+ when(executor.runningTasks).thenAnswer(_ => runningTasks)
+ when(executor.conf).thenReturn(conf)
+
+ // We don't really verify the data, just pass it around.
+ val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
+
+ val numTasks = 1000
+ val tasksKilled = new TrieMap[Long, Boolean]()
+ val tasksExecuted = new TrieMap[Long, Boolean]()
+
+ // Fake tasks with different taskIds.
+ val taskDescriptions = (1 to numTasks).map {
+ taskId => new TaskDescription(taskId, 2, "1", "TASK ${taskId}", 19,
+ 1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new
Properties, 1,
+ Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
+ }
+ assert(taskDescriptions.length == numTasks)
+
+ def getFakeTaskRunner(taskDescription: TaskDescription):
Executor#TaskRunner = {
+ new executor.TaskRunner(backend, taskDescription, None) {
+ override def run(): Unit = {
+ tasksExecuted.put(taskDescription.taskId, true)
+ logInfo(s"task ${taskDescription.taskId} runs.")
+ }
+
+ override def kill(interruptThread: Boolean, reason: String): Unit = {
+ logInfo(s"task ${taskDescription.taskId} killed.")
+ tasksKilled.put(taskDescription.taskId, true)
+ }
+ }
+ }
+
+ // Feed the fake task-runners to be executed by the executor.
+ val firstLaunchTask = getFakeTaskRunner(taskDescriptions(1))
+ val otherTasks = taskDescriptions.slice(1,
numTasks).map(getFakeTaskRunner(_)).toArray
+ assert (otherTasks.length == numTasks - 1)
+ // Workaround for compilation issue around Mockito.doReturn
+ doReturn(firstLaunchTask, otherTasks: _*).when(executor).
+ createTaskRunner(any(), any())
+
+ // The reverse order of events can happen when the scheduler tries to
cancel a task right
+ // after launching it.
+ taskDescriptions.foreach { taskDescription =>
+ val buffer = new
SerializableBuffer(TaskDescription.encode(taskDescription))
+ backend.self.send(KillTask(taskDescription.taskId, "exec1", false,
"test"))
+ backend.self.send(LaunchTask(buffer))
+ }
+
+ eventually(timeout(10.seconds)) {
+ verify(runningTasks, times(numTasks)).put(any(), any())
+ }
+
+ assert(tasksExecuted.size == tasksKilled.size,
+ s"Tasks killed ${tasksKilled.size} != tasks executed
${tasksExecuted.size}")
+ assert(tasksExecuted.keySet == tasksKilled.keySet)
+ logInfo(s"Task executed ${tasksExecuted.size}, task killed
${tasksKilled.size}")
+ } finally {
+ if (backend != null) {
+ backend.rpcEnv.shutdown()
+ }
+ threadPool.shutdownNow()
+ }
+ }
+
private def createMockEnv(conf: SparkConf, serializer: JavaSerializer,
rpcEnv: Option[RpcEnv] = None): SparkEnv = {
val mockEnv = mock[SparkEnv]
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index a237447b0fa..7f7b10c8c33 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,7 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
import java.net.URL
import java.nio.ByteBuffer
import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.immutable
@@ -321,13 +321,7 @@ class ExecutorSuite extends SparkFunSuite
nonZeroAccumulator.add(1)
metrics.registerAccumulator(nonZeroAccumulator)
- val executorClass = classOf[Executor]
- val tasksMap = {
- val field =
-
executorClass.getDeclaredField("org$apache$spark$executor$Executor$$runningTasks")
- field.setAccessible(true)
- field.get(executor).asInstanceOf[ConcurrentHashMap[Long,
executor.TaskRunner]]
- }
+ val tasksMap = executor.runningTasks
val mockTaskRunner = mock[executor.TaskRunner]
val mockTask = mock[Task[Any]]
when(mockTask.metrics).thenReturn(metrics)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]