This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 93289a5dc92 Revert "[SPARK-38916][CORE] Tasks not killed caused by race conditions between killTask() and launchTask()"
93289a5dc92 is described below
commit 93289a5dc929c97d10df03853161d5b931538ba5
Author: Wenchen Fan <[email protected]>
AuthorDate: Mon Apr 25 11:57:51 2022 +0800
Revert "[SPARK-38916][CORE] Tasks not killed caused by race conditions
between killTask() and launchTask()"
This reverts commit 9dd64d40c91253c275fef2313c6a326ef72112cb.
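For context on what is being reverted: SPARK-38916 made killTask() record a "kill mark" before looking up the task runner, so that a kill arriving ahead of the corresponding launchTask() would not be dropped. A minimal, hypothetical sketch of that handshake (simplified names, not the exact Spark code):

    import java.util.concurrent.{ConcurrentHashMap, Executors}

    class MiniRunner(val taskId: Long) extends Runnable {
      @volatile private var killed = false
      def kill(reason: String): Unit = { killed = true }
      override def run(): Unit = if (!killed) { /* ... run the task ... */ }
    }

    class MiniExecutor {
      private val pool = Executors.newCachedThreadPool()
      private val running = new ConcurrentHashMap[Long, MiniRunner]()
      // taskId -> (interruptThread, reason, timestampMs)
      private val killMarks = new ConcurrentHashMap[Long, (Boolean, String, Long)]()

      def killTask(taskId: Long, interrupt: Boolean, reason: String): Unit = {
        // Record the kill first so a concurrent launchTask() cannot miss it.
        killMarks.put(taskId, (interrupt, reason, System.currentTimeMillis()))
        val runner = running.get(taskId)
        if (runner != null) {
          runner.kill(reason)
          killMarks.remove(taskId) // we reached the runner, so the mark is spent
        }
      }

      def launchTask(runner: MiniRunner): Unit = {
        running.put(runner.taskId, runner)
        val mark = killMarks.get(runner.taskId)
        if (mark != null) { // the kill raced ahead of the launch: honor it now
          runner.kill(mark._2)
          killMarks.remove(runner.taskId)
        }
        pool.execute(runner)
      }
    }

The revert below removes Spark's version of this mechanism (killMarks, its TTL cleanup, and the createTaskRunner() test seam) and restores the original launchTask()/killTask() bodies.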
---
.../scala/org/apache/spark/executor/Executor.scala | 51 +-----
.../CoarseGrainedExecutorBackendSuite.scala | 185 +--------------------
.../org/apache/spark/executor/ExecutorSuite.scala | 10 +-
3 files changed, 16 insertions(+), 230 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 4c84224dd05..3f1023e3491 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -83,7 +83,7 @@ private[spark] class Executor(
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
- private[executor] val conf = env.conf
+ private val conf = env.conf
// No ip or host:port - just hostname
Utils.checkHost(executorHostname)
@@ -104,7 +104,7 @@ private[spark] class Executor(
// Use UninterruptibleThread to run tasks so that we can allow running codes without being
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
// will hang forever if some methods are interrupted.
- private[executor] val threadPool = {
+ private val threadPool = {
val threadFactory = new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("Executor task launch worker-%d")
@@ -174,33 +174,7 @@ private[spark] class Executor(
private val maxResultSize = conf.get(MAX_RESULT_SIZE)
// Maintains the list of running tasks.
- private[executor] val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
-
- // Kill mark TTL in milliseconds - 10 seconds.
- private val KILL_MARK_TTL_MS = 10000L
-
- // Kill marks with interruptThread flag, kill reason and timestamp.
- // This is to avoid dropping the kill event when killTask() is called before launchTask().
- private[executor] val killMarks = new ConcurrentHashMap[Long, (Boolean, String, Long)]
-
- private val killMarkCleanupTask = new Runnable {
- override def run(): Unit = {
- val oldest = System.currentTimeMillis() - KILL_MARK_TTL_MS
- val iter = killMarks.entrySet().iterator()
- while (iter.hasNext) {
- if (iter.next().getValue._3 < oldest) {
- iter.remove()
- }
- }
- }
- }
-
- // Kill mark cleanup thread executor.
- private val killMarkCleanupService =
- ThreadUtils.newDaemonSingleThreadScheduledExecutor("executor-kill-mark-cleanup")
-
- killMarkCleanupService.scheduleAtFixedRate(
- killMarkCleanupTask, KILL_MARK_TTL_MS, KILL_MARK_TTL_MS, TimeUnit.MILLISECONDS)
+ private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
/**
* When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES`
@@ -290,18 +264,9 @@ private[spark] class Executor(
decommissioned = true
}
- private[executor] def createTaskRunner(context: ExecutorBackend,
- taskDescription: TaskDescription) = new TaskRunner(context, taskDescription, plugins)
-
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
- val taskId = taskDescription.taskId
- val tr = createTaskRunner(context, taskDescription)
- runningTasks.put(taskId, tr)
- val killMark = killMarks.get(taskId)
- if (killMark != null) {
- tr.kill(killMark._1, killMark._2)
- killMarks.remove(taskId)
- }
+ val tr = new TaskRunner(context, taskDescription, plugins)
+ runningTasks.put(taskDescription.taskId, tr)
threadPool.execute(tr)
if (decommissioned) {
log.error(s"Launching a task while in decommissioned state.")
@@ -309,7 +274,6 @@ private[spark] class Executor(
}
def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
- killMarks.put(taskId, (interruptThread, reason, System.currentTimeMillis()))
val taskRunner = runningTasks.get(taskId)
if (taskRunner != null) {
if (taskReaperEnabled) {
@@ -332,8 +296,6 @@ private[spark] class Executor(
} else {
taskRunner.kill(interruptThread = interruptThread, reason = reason)
}
- // Safe to remove kill mark as we got a chance with the TaskRunner.
- killMarks.remove(taskId)
}
}
@@ -372,9 +334,6 @@ private[spark] class Executor(
if (threadPool != null) {
threadPool.shutdown()
}
- if (killMarkCleanupService != null) {
- killMarkCleanupService.shutdown()
- }
if (replClassLoader != null && plugins != null) {
// Notify plugins that executor is shutting down so they can terminate cleanly
Utils.withContextClassLoader(replClassLoader) {
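An aside on the killMarkCleanupTask deleted above: it is a standard TTL-eviction idiom over a ConcurrentHashMap, whose weakly consistent iterators tolerate concurrent mutation and support remove(). A self-contained sketch of the same idiom using plain java.util.concurrent rather than Spark's ThreadUtils (names are illustrative, not Spark's):

    import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

    object TtlEvictionSketch {
      private val ttlMs = 10000L
      // key -> insertion timestamp in milliseconds
      private val marks = new ConcurrentHashMap[Long, Long]()

      private val scheduler = Executors.newSingleThreadScheduledExecutor()
      scheduler.scheduleAtFixedRate(new Runnable {
        override def run(): Unit = {
          val oldest = System.currentTimeMillis() - ttlMs
          val iter = marks.entrySet().iterator()
          while (iter.hasNext) {
            // Weakly consistent iterator: remove() is safe here even
            // while other threads keep inserting.
            if (iter.next().getValue < oldest) iter.remove()
          }
        }
      }, ttlMs, ttlMs, TimeUnit.MILLISECONDS)
    }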
diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index 5210990f3b9..4909a586d31 100644
--- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -21,17 +21,14 @@ import java.io.File
import java.net.URL
import java.nio.ByteBuffer
import java.util.Properties
-import java.util.concurrent.ConcurrentHashMap
-import scala.collection.concurrent.TrieMap
import scala.collection.mutable
import scala.concurrent.duration._
import org.json4s.{DefaultFormats, Extraction}
import org.json4s.JsonAST.{JArray, JObject}
import org.json4s.JsonDSL._
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito._
+import org.mockito.Mockito.when
import org.scalatest.concurrent.Eventually.{eventually, timeout}
import org.scalatestplus.mockito.MockitoSugar
@@ -42,9 +39,9 @@ import org.apache.spark.resource.ResourceUtils._
import org.apache.spark.resource.TestResourceIDs._
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{KillTask, LaunchTask}
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask
import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.util.{SerializableBuffer, ThreadUtils, Utils}
+import org.apache.spark.util.{SerializableBuffer, Utils}
class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
with LocalSparkContext with MockitoSugar {
@@ -360,182 +357,6 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
assert(arg.bindAddress == "bindaddress1")
}
- /**
- * This testcase is to verify that [[Executor.killTask()]] will always cancel a task that is
- * being executed in [[Executor.TaskRunner]].
- */
- test(s"Tasks launched should always be cancelled.") {
- val conf = new SparkConf
- val securityMgr = new SecurityManager(conf)
- val serializer = new JavaSerializer(conf)
- val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
- var backend: CoarseGrainedExecutorBackend = null
-
- try {
- val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
- val env = createMockEnv(conf, serializer, Some(rpcEnv))
- backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1",
- "host1", "host1", 4, env, None,
- resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
-
- backend.rpcEnv.setupEndpoint("Executor 1", backend)
- backend.executor = mock[Executor](CALLS_REAL_METHODS)
- val executor = backend.executor
- // Mock the executor.
- when(executor.threadPool).thenReturn(threadPool)
- val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
- when(executor.runningTasks).thenAnswer(_ => runningTasks)
- when(executor.conf).thenReturn(conf)
-
- // We don't really verify the data, just pass it around.
- val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
-
- val numTasks = 1000
- val tasksKilled = new TrieMap[Long, Boolean]()
- val tasksExecuted = new TrieMap[Long, Boolean]()
-
- // Fake tasks with different taskIds.
- val taskDescriptions = (1 to numTasks).map {
- taskId => new TaskDescription(taskId, 2, "1", "TASK ${taskId}", 19,
- 1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
- Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
- }
- assert(taskDescriptions.length == numTasks)
-
- def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
- new executor.TaskRunner(backend, taskDescription, None) {
- override def run(): Unit = {
- tasksExecuted.put(taskDescription.taskId, true)
- logInfo(s"task ${taskDescription.taskId} runs.")
- }
-
- override def kill(interruptThread: Boolean, reason: String): Unit = {
- logInfo(s"task ${taskDescription.taskId} killed.")
- tasksKilled.put(taskDescription.taskId, true)
- }
- }
- }
-
- // Feed the fake task-runners to be executed by the executor.
- val firstLaunchTask = getFakeTaskRunner(taskDescriptions(1))
- val otherTasks = taskDescriptions.slice(1, numTasks).map(getFakeTaskRunner(_)).toArray
- assert (otherTasks.length == numTasks - 1)
- // Workaround for compilation issue around Mockito.doReturn
- doReturn(firstLaunchTask, otherTasks: _*).when(executor).createTaskRunner(any(), any())
-
- // Launch tasks and quickly kill them so that TaskRunner.killTask will be triggered.
- taskDescriptions.foreach { taskDescription =>
- val buffer = new SerializableBuffer(TaskDescription.encode(taskDescription))
- backend.self.send(LaunchTask(buffer))
- Thread.sleep(1)
- backend.self.send(KillTask(taskDescription.taskId, "exec1", false, "test"))
- }
-
- eventually(timeout(10.seconds)) {
- verify(runningTasks, times(numTasks)).put(any(), any())
- }
-
- assert(tasksExecuted.size == tasksKilled.size,
- s"Tasks killed ${tasksKilled.size} != tasks executed
${tasksExecuted.size}")
- assert(tasksExecuted.keySet == tasksKilled.keySet)
- logInfo(s"Task executed ${tasksExecuted.size}, task killed
${tasksKilled.size}")
- } finally {
- if (backend != null) {
- backend.rpcEnv.shutdown()
- }
- threadPool.shutdownNow()
- }
- }
-
- /**
- * This testcase is to verify that [[Executor.killTask()]] will always cancel a task even if
- * it has not been launched yet.
- */
- test(s"Tasks not launched should always be cancelled.") {
- val conf = new SparkConf
- val securityMgr = new SecurityManager(conf)
- val serializer = new JavaSerializer(conf)
- val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
- var backend: CoarseGrainedExecutorBackend = null
-
- try {
- val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
- val env = createMockEnv(conf, serializer, Some(rpcEnv))
- backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1",
- "host1", "host1", 4, env, None,
- resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
-
- backend.rpcEnv.setupEndpoint("Executor 1", backend)
- backend.executor = mock[Executor](CALLS_REAL_METHODS)
- val executor = backend.executor
- // Mock the executor.
- when(executor.threadPool).thenReturn(threadPool)
- val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
- when(executor.runningTasks).thenAnswer(_ => runningTasks)
- when(executor.conf).thenReturn(conf)
-
- // We don't really verify the data, just pass it around.
- val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
-
- val numTasks = 1000
- val tasksKilled = new TrieMap[Long, Boolean]()
- val tasksExecuted = new TrieMap[Long, Boolean]()
-
- // Fake tasks with different taskIds.
- val taskDescriptions = (1 to numTasks).map {
- taskId => new TaskDescription(taskId, 2, "1", "TASK ${taskId}", 19,
- 1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
- Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
- }
- assert(taskDescriptions.length == numTasks)
-
- def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
- new executor.TaskRunner(backend, taskDescription, None) {
- override def run(): Unit = {
- tasksExecuted.put(taskDescription.taskId, true)
- logInfo(s"task ${taskDescription.taskId} runs.")
- }
-
- override def kill(interruptThread: Boolean, reason: String): Unit = {
- logInfo(s"task ${taskDescription.taskId} killed.")
- tasksKilled.put(taskDescription.taskId, true)
- }
- }
- }
-
- // Feed the fake task-runners to be executed by the executor.
- val firstLaunchTask = getFakeTaskRunner(taskDescriptions(1))
- val otherTasks = taskDescriptions.slice(1, numTasks).map(getFakeTaskRunner(_)).toArray
- assert (otherTasks.length == numTasks - 1)
- // Workaround for compilation issue around Mockito.doReturn
- doReturn(firstLaunchTask, otherTasks: _*).when(executor).createTaskRunner(any(), any())
-
- // The reverse order of events can happen when the scheduler tries to cancel a task right
- // after launching it.
- taskDescriptions.foreach { taskDescription =>
- val buffer = new SerializableBuffer(TaskDescription.encode(taskDescription))
- backend.self.send(KillTask(taskDescription.taskId, "exec1", false, "test"))
- backend.self.send(LaunchTask(buffer))
- }
-
- eventually(timeout(10.seconds)) {
- verify(runningTasks, times(numTasks)).put(any(), any())
- }
-
- assert(tasksExecuted.size == tasksKilled.size,
- s"Tasks killed ${tasksKilled.size} != tasks executed
${tasksExecuted.size}")
- assert(tasksExecuted.keySet == tasksKilled.keySet)
- logInfo(s"Task executed ${tasksExecuted.size}, task killed
${tasksKilled.size}")
- } finally {
- if (backend != null) {
- backend.rpcEnv.shutdown()
- }
- threadPool.shutdownNow()
- }
- }
-
private def createMockEnv(conf: SparkConf, serializer: JavaSerializer,
rpcEnv: Option[RpcEnv] = None): SparkEnv = {
val mockEnv = mock[SparkEnv]
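The deleted tests above mention a "workaround for compilation issue around Mockito.doReturn". The underlying issue: doReturn is overloaded as doReturn(Object) and doReturn(Object, Object...), and Scala's overload resolution cannot always pick between them, so the idiom is to pass an explicit varargs tail. A minimal, hypothetical illustration (Greeter is an invented class, not Spark code):

    import org.mockito.Mockito.{doReturn, mock}

    class Greeter {
      def greet(): String = "hello"
    }

    object DoReturnSketch extends App {
      val m = mock(classOf[Greeter])
      // The explicit (here empty) varargs tail selects the
      // doReturn(Object, Object...) overload unambiguously.
      doReturn("stubbed", Nil: _*).when(m).greet()
      println(m.greet()) // prints "stubbed"
    }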
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 7f7b10c8c33..a237447b0fa 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,7 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
import java.net.URL
import java.nio.ByteBuffer
import java.util.Properties
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.immutable
@@ -321,7 +321,13 @@ class ExecutorSuite extends SparkFunSuite
nonZeroAccumulator.add(1)
metrics.registerAccumulator(nonZeroAccumulator)
- val tasksMap = executor.runningTasks
+ val executorClass = classOf[Executor]
+ val tasksMap = {
+ val field =
+ executorClass.getDeclaredField("org$apache$spark$executor$Executor$$runningTasks")
+ field.setAccessible(true)
+ field.get(executor).asInstanceOf[ConcurrentHashMap[Long, executor.TaskRunner]]
+ }
val mockTaskRunner = mock[executor.TaskRunner]
val mockTask = mock[Task[Any]]
when(mockTask.metrics).thenReturn(metrics)
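The restored lines above reach the now-private runningTasks map through reflection; the long field name is scalac's mangling of a private val that nested classes (such as TaskRunner) still need to access. A toy sketch of the same technique, with a suffix-based lookup that does not hard-code the mangled form (Box, items, and Peek are invented names):

    import java.util.concurrent.ConcurrentHashMap

    class Box {
      private val items = new ConcurrentHashMap[Long, String]()
      // An inner class touching `items` makes scalac emit a
      // name-mangled field on Box instead of a plain private one.
      class Peek { def size: Int = items.size() }
    }

    object ReflectSketch extends App {
      val box = new Box
      // Search by suffix rather than hard-coding the mangled name,
      // since the exact form depends on the package and scalac version.
      val field = classOf[Box].getDeclaredFields
        .find(_.getName.endsWith("items"))
        .getOrElse(sys.error("items field not found"))
      field.setAccessible(true)
      val map = field.get(box).asInstanceOf[ConcurrentHashMap[Long, String]]
      map.put(1L, "one")
      println(map.size()) // prints 1
    }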
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]