This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 93289a5dc92 Revert "[SPARK-38916][CORE] Tasks not killed caused by race conditions between killTask() and launchTask()"
93289a5dc92 is described below

commit 93289a5dc929c97d10df03853161d5b931538ba5
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Mon Apr 25 11:57:51 2022 +0800

    Revert "[SPARK-38916][CORE] Tasks not killed caused by race conditions 
between killTask() and launchTask()"
    
    This reverts commit 9dd64d40c91253c275fef2313c6a326ef72112cb.
---
 .../scala/org/apache/spark/executor/Executor.scala |  51 +-----
 .../CoarseGrainedExecutorBackendSuite.scala        | 185 +--------------------
 .../org/apache/spark/executor/ExecutorSuite.scala  |  10 +-
 3 files changed, 16 insertions(+), 230 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 4c84224dd05..3f1023e3491 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -83,7 +83,7 @@ private[spark] class Executor(
 
   private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
 
-  private[executor] val conf = env.conf
+  private val conf = env.conf
 
   // No ip or host:port - just hostname
   Utils.checkHost(executorHostname)
@@ -104,7 +104,7 @@ private[spark] class Executor(
   // Use UninterruptibleThread to run tasks so that we can allow running codes without being
   // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
   // will hang forever if some methods are interrupted.
-  private[executor] val threadPool = {
+  private val threadPool = {
     val threadFactory = new ThreadFactoryBuilder()
       .setDaemon(true)
       .setNameFormat("Executor task launch worker-%d")
@@ -174,33 +174,7 @@ private[spark] class Executor(
   private val maxResultSize = conf.get(MAX_RESULT_SIZE)
 
   // Maintains the list of running tasks.
-  private[executor] val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
-
-  // Kill mark TTL in milliseconds - 10 seconds.
-  private val KILL_MARK_TTL_MS = 10000L
-
-  // Kill marks with interruptThread flag, kill reason and timestamp.
-  // This is to avoid dropping the kill event when killTask() is called before launchTask().
-  private[executor] val killMarks = new ConcurrentHashMap[Long, (Boolean, String, Long)]
-
-  private val killMarkCleanupTask = new Runnable {
-    override def run(): Unit = {
-      val oldest = System.currentTimeMillis() - KILL_MARK_TTL_MS
-      val iter = killMarks.entrySet().iterator()
-      while (iter.hasNext) {
-        if (iter.next().getValue._3 < oldest) {
-          iter.remove()
-        }
-      }
-    }
-  }
-
-  // Kill mark cleanup thread executor.
-  private val killMarkCleanupService =
-    ThreadUtils.newDaemonSingleThreadScheduledExecutor("executor-kill-mark-cleanup")
-
-  killMarkCleanupService.scheduleAtFixedRate(
-    killMarkCleanupTask, KILL_MARK_TTL_MS, KILL_MARK_TTL_MS, TimeUnit.MILLISECONDS)
+  private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
 
   /**
   * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES`
@@ -290,18 +264,9 @@ private[spark] class Executor(
     decommissioned = true
   }
 
-  private[executor] def createTaskRunner(context: ExecutorBackend,
-    taskDescription: TaskDescription) = new TaskRunner(context, taskDescription, plugins)
-
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val taskId = taskDescription.taskId
-    val tr = createTaskRunner(context, taskDescription)
-    runningTasks.put(taskId, tr)
-    val killMark = killMarks.get(taskId)
-    if (killMark != null) {
-      tr.kill(killMark._1, killMark._2)
-      killMarks.remove(taskId)
-    }
+    val tr = new TaskRunner(context, taskDescription, plugins)
+    runningTasks.put(taskDescription.taskId, tr)
     threadPool.execute(tr)
     if (decommissioned) {
       log.error(s"Launching a task while in decommissioned state.")
@@ -309,7 +274,6 @@ private[spark] class Executor(
   }
 
   def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
-    killMarks.put(taskId, (interruptThread, reason, System.currentTimeMillis()))
     val taskRunner = runningTasks.get(taskId)
     if (taskRunner != null) {
       if (taskReaperEnabled) {
@@ -332,8 +296,6 @@ private[spark] class Executor(
       } else {
         taskRunner.kill(interruptThread = interruptThread, reason = reason)
       }
-      // Safe to remove kill mark as we got a chance with the TaskRunner.
-      killMarks.remove(taskId)
     }
   }
 
@@ -372,9 +334,6 @@ private[spark] class Executor(
       if (threadPool != null) {
         threadPool.shutdown()
       }
-      if (killMarkCleanupService != null) {
-        killMarkCleanupService.shutdown()
-      }
       if (replClassLoader != null && plugins != null) {
         // Notify plugins that executor is shutting down so they can terminate cleanly
         Utils.withContextClassLoader(replClassLoader) {
diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index 5210990f3b9..4909a586d31 100644
--- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -21,17 +21,14 @@ import java.io.File
 import java.net.URL
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.ConcurrentHashMap
 
-import scala.collection.concurrent.TrieMap
 import scala.collection.mutable
 import scala.concurrent.duration._
 
 import org.json4s.{DefaultFormats, Extraction}
 import org.json4s.JsonAST.{JArray, JObject}
 import org.json4s.JsonDSL._
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito._
+import org.mockito.Mockito.when
 import org.scalatest.concurrent.Eventually.{eventually, timeout}
 import org.scalatestplus.mockito.MockitoSugar
 
@@ -42,9 +39,9 @@ import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs._
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{KillTask, LaunchTask}
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask
 import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.util.{SerializableBuffer, ThreadUtils, Utils}
+import org.apache.spark.util.{SerializableBuffer, Utils}
 
 class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
     with LocalSparkContext with MockitoSugar {
@@ -360,182 +357,6 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
     assert(arg.bindAddress == "bindaddress1")
   }
 
-  /**
-   * This testcase is to verify that [[Executor.killTask()]] will always cancel a task that is
-   * being executed in [[Executor.TaskRunner]].
-   */
-  test(s"Tasks launched should always be cancelled.")  {
-    val conf = new SparkConf
-    val securityMgr = new SecurityManager(conf)
-    val serializer = new JavaSerializer(conf)
-    val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
-    var backend: CoarseGrainedExecutorBackend = null
-
-    try {
-      val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
-      val env = createMockEnv(conf, serializer, Some(rpcEnv))
-      backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1",
-        "host1", "host1", 4, env, None,
-        resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
-
-      backend.rpcEnv.setupEndpoint("Executor 1", backend)
-      backend.executor = mock[Executor](CALLS_REAL_METHODS)
-      val executor = backend.executor
-      // Mock the executor.
-      when(executor.threadPool).thenReturn(threadPool)
-      val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
-      when(executor.runningTasks).thenAnswer(_ => runningTasks)
-      when(executor.conf).thenReturn(conf)
-
-      // We don't really verify the data, just pass it around.
-      val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
-
-      val numTasks = 1000
-      val tasksKilled = new TrieMap[Long, Boolean]()
-      val tasksExecuted = new TrieMap[Long, Boolean]()
-
-      // Fake tasks with different taskIds.
-      val taskDescriptions = (1 to numTasks).map {
-        taskId => new TaskDescription(taskId, 2, "1", "TASK ${taskId}", 19,
-          1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
-          Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
-      }
-      assert(taskDescriptions.length == numTasks)
-
-      def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
-        new executor.TaskRunner(backend, taskDescription, None) {
-          override def run(): Unit = {
-            tasksExecuted.put(taskDescription.taskId, true)
-            logInfo(s"task ${taskDescription.taskId} runs.")
-          }
-
-          override def kill(interruptThread: Boolean, reason: String): Unit = {
-            logInfo(s"task ${taskDescription.taskId} killed.")
-            tasksKilled.put(taskDescription.taskId, true)
-          }
-        }
-      }
-
-      // Feed the fake task-runners to be executed by the executor.
-      val firstLaunchTask = getFakeTaskRunner(taskDescriptions(1))
-      val otherTasks = taskDescriptions.slice(1, numTasks).map(getFakeTaskRunner(_)).toArray
-      assert (otherTasks.length == numTasks - 1)
-      // Workaround for compilation issue around Mockito.doReturn
-      doReturn(firstLaunchTask, otherTasks: _*).when(executor).
-        createTaskRunner(any(), any())
-
-      // Launch tasks and quickly kill them so that TaskRunner.killTask will be triggered.
-      taskDescriptions.foreach { taskDescription =>
-        val buffer = new SerializableBuffer(TaskDescription.encode(taskDescription))
-        backend.self.send(LaunchTask(buffer))
-        Thread.sleep(1)
-        backend.self.send(KillTask(taskDescription.taskId, "exec1", false, "test"))
-      }
-
-      eventually(timeout(10.seconds)) {
-        verify(runningTasks, times(numTasks)).put(any(), any())
-      }
-
-      assert(tasksExecuted.size == tasksKilled.size,
-        s"Tasks killed ${tasksKilled.size} != tasks executed 
${tasksExecuted.size}")
-      assert(tasksExecuted.keySet == tasksKilled.keySet)
-      logInfo(s"Task executed ${tasksExecuted.size}, task killed 
${tasksKilled.size}")
-    } finally {
-      if (backend != null) {
-        backend.rpcEnv.shutdown()
-      }
-      threadPool.shutdownNow()
-    }
-  }
-
-  /**
-   * This testcase is to verify that [[Executor.killTask()]] will always cancel a task even if
-   * it has not been launched yet.
-   */
-  test(s"Tasks not launched should always be cancelled.")  {
-    val conf = new SparkConf
-    val securityMgr = new SecurityManager(conf)
-    val serializer = new JavaSerializer(conf)
-    val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
-    var backend: CoarseGrainedExecutorBackend = null
-
-    try {
-      val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
-      val env = createMockEnv(conf, serializer, Some(rpcEnv))
-      backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1",
-        "host1", "host1", 4, env, None,
-        resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
-
-      backend.rpcEnv.setupEndpoint("Executor 1", backend)
-      backend.executor = mock[Executor](CALLS_REAL_METHODS)
-      val executor = backend.executor
-      // Mock the executor.
-      when(executor.threadPool).thenReturn(threadPool)
-      val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
-      when(executor.runningTasks).thenAnswer(_ => runningTasks)
-      when(executor.conf).thenReturn(conf)
-
-      // We don't really verify the data, just pass it around.
-      val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
-
-      val numTasks = 1000
-      val tasksKilled = new TrieMap[Long, Boolean]()
-      val tasksExecuted = new TrieMap[Long, Boolean]()
-
-      // Fake tasks with different taskIds.
-      val taskDescriptions = (1 to numTasks).map {
-        taskId => new TaskDescription(taskId, 2, "1", "TASK ${taskId}", 19,
-          1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
-          Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
-      }
-      assert(taskDescriptions.length == numTasks)
-
-      def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
-        new executor.TaskRunner(backend, taskDescription, None) {
-          override def run(): Unit = {
-            tasksExecuted.put(taskDescription.taskId, true)
-            logInfo(s"task ${taskDescription.taskId} runs.")
-          }
-
-          override def kill(interruptThread: Boolean, reason: String): Unit = {
-            logInfo(s"task ${taskDescription.taskId} killed.")
-            tasksKilled.put(taskDescription.taskId, true)
-          }
-        }
-      }
-
-      // Feed the fake task-runners to be executed by the executor.
-      val firstLaunchTask = getFakeTaskRunner(taskDescriptions(1))
-      val otherTasks = taskDescriptions.slice(1, numTasks).map(getFakeTaskRunner(_)).toArray
-      assert (otherTasks.length == numTasks - 1)
-      // Workaround for compilation issue around Mockito.doReturn
-      doReturn(firstLaunchTask, otherTasks: _*).when(executor).
-        createTaskRunner(any(), any())
-
-      // The reverse order of events can happen when the scheduler tries to cancel a task right
-      // after launching it.
-      taskDescriptions.foreach { taskDescription =>
-        val buffer = new SerializableBuffer(TaskDescription.encode(taskDescription))
-        backend.self.send(KillTask(taskDescription.taskId, "exec1", false, "test"))
-        backend.self.send(LaunchTask(buffer))
-      }
-
-      eventually(timeout(10.seconds)) {
-        verify(runningTasks, times(numTasks)).put(any(), any())
-      }
-
-      assert(tasksExecuted.size == tasksKilled.size,
-        s"Tasks killed ${tasksKilled.size} != tasks executed 
${tasksExecuted.size}")
-      assert(tasksExecuted.keySet == tasksKilled.keySet)
-      logInfo(s"Task executed ${tasksExecuted.size}, task killed 
${tasksKilled.size}")
-    } finally {
-      if (backend != null) {
-        backend.rpcEnv.shutdown()
-      }
-      threadPool.shutdownNow()
-    }
-  }
-
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer,
       rpcEnv: Option[RpcEnv] = None): SparkEnv = {
     val mockEnv = mock[SparkEnv]
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 7f7b10c8c33..a237447b0fa 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,7 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
 import java.net.URL
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.immutable
@@ -321,7 +321,13 @@ class ExecutorSuite extends SparkFunSuite
       nonZeroAccumulator.add(1)
       metrics.registerAccumulator(nonZeroAccumulator)
 
-      val tasksMap = executor.runningTasks
+      val executorClass = classOf[Executor]
+      val tasksMap = {
+        val field =
+          executorClass.getDeclaredField("org$apache$spark$executor$Executor$$runningTasks")
+        field.setAccessible(true)
+        field.get(executor).asInstanceOf[ConcurrentHashMap[Long, executor.TaskRunner]]
+      }
       val mockTaskRunner = mock[executor.TaskRunner]
       val mockTask = mock[Task[Any]]
       when(mockTask.metrics).thenReturn(metrics)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
