Repository: spark
Updated Branches:
  refs/heads/master cf9367826 -> 1c9a386c6


[SPARK-13602][CORE] Add shutdown hook to DriverRunner to prevent driver process 
leak

## What changes were proposed in this pull request?

Added a shutdown hook to DriverRunner to kill the driver process in case the
Worker JVM exits suddenly and the `WorkerWatcher` is unable to properly catch
this.  Did some cleanup to consolidate driver state management and the setting
of finalized vars within the running thread.

## How was this patch tested?

Added unit tests to verify that the final state and exception variables are set
accordingly for successful, failed, and errored runs of the driver process.
Retrofitted the existing test to verify that killing the mocked process ends
with the correct state and stops properly.

Manually tested (with deploy-mode=cluster) that the shutdown hook is called by
forcibly exiting the `Worker` at various points in the code with the
`WorkerWatcher` both disabled and enabled.  Also, manually killed the driver
through the UI and verified that the `DriverRunner` was interrupted, killed the
process and exited properly.

Author: Bryan Cutler <[email protected]>

Closes #11746 from BryanCutler/DriverRunner-shutdown-hook-SPARK-13602.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c9a386c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c9a386c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c9a386c

Branch: refs/heads/master
Commit: 1c9a386c6b6812a3931f3fb0004249894a01f657
Parents: cf93678
Author: Bryan Cutler <[email protected]>
Authored: Thu Aug 11 14:49:11 2016 -0700
Committer: Marcelo Vanzin <[email protected]>
Committed: Thu Aug 11 14:49:11 2016 -0700

----------------------------------------------------------------------
 .../spark/deploy/worker/DriverRunner.scala      | 119 +++++++++++--------
 .../spark/deploy/worker/DriverRunnerTest.scala  |  73 +++++++++++-
 2 files changed, 142 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1c9a386c/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala 
b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index f4376de..289b0b9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -32,7 +32,7 @@ import org.apache.spark.deploy.master.DriverState
 import org.apache.spark.deploy.master.DriverState.DriverState
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.util.{Clock, SystemClock, Utils}
+import org.apache.spark.util.{Clock, ShutdownHookManager, SystemClock, Utils}
 
 /**
  * Manages the execution of one driver, including automatically restarting the 
driver on failure.
@@ -53,9 +53,11 @@ private[deploy] class DriverRunner(
   @volatile private var killed = false
 
   // Populated once finished
-  private[worker] var finalState: Option[DriverState] = None
-  private[worker] var finalException: Option[Exception] = None
-  private var finalExitCode: Option[Int] = None
+  @volatile private[worker] var finalState: Option[DriverState] = None
+  @volatile private[worker] var finalException: Option[Exception] = None
+
+  // Timeout to wait for when trying to terminate a driver.
+  private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000
 
   // Decoupled for testing
   def setClock(_clock: Clock): Unit = {
@@ -78,49 +80,53 @@ private[deploy] class DriverRunner(
   private[worker] def start() = {
     new Thread("DriverRunner for " + driverId) {
       override def run() {
+        var shutdownHook: AnyRef = null
         try {
-          val driverDir = createWorkingDirectory()
-          val localJarFilename = downloadUserJar(driverDir)
-
-          def substituteVariables(argument: String): String = argument match {
-            case "{{WORKER_URL}}" => workerUrl
-            case "{{USER_JAR}}" => localJarFilename
-            case other => other
+          shutdownHook = ShutdownHookManager.addShutdownHook { () =>
+            logInfo(s"Worker shutting down, killing driver $driverId")
+            kill()
           }
 
-          // TODO: If we add ability to submit multiple jars they should also 
be added here
-          val builder = CommandUtils.buildProcessBuilder(driverDesc.command, 
securityManager,
-            driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)
-          launchDriver(builder, driverDir, driverDesc.supervise)
-        }
-        catch {
-          case e: Exception => finalException = Some(e)
-        }
+          // prepare driver jars and run driver
+          val exitCode = prepareAndRunDriver()
 
-        val state =
-          if (killed) {
-            DriverState.KILLED
-          } else if (finalException.isDefined) {
-            DriverState.ERROR
+          // set final state depending on if forcibly killed and process exit 
code
+          finalState = if (exitCode == 0) {
+            Some(DriverState.FINISHED)
+          } else if (killed) {
+            Some(DriverState.KILLED)
           } else {
-            finalExitCode match {
-              case Some(0) => DriverState.FINISHED
-              case _ => DriverState.FAILED
-            }
+            Some(DriverState.FAILED)
           }
+        } catch {
+          case e: Exception =>
+            kill()
+            finalState = Some(DriverState.ERROR)
+            finalException = Some(e)
+        } finally {
+          if (shutdownHook != null) {
+            ShutdownHookManager.removeShutdownHook(shutdownHook)
+          }
+        }
 
-        finalState = Some(state)
-
-        worker.send(DriverStateChanged(driverId, state, finalException))
+        // notify worker of final driver state, possible exception
+        worker.send(DriverStateChanged(driverId, finalState.get, 
finalException))
       }
     }.start()
   }
 
   /** Terminate this driver (or prevent it from ever starting if not yet 
started) */
-  private[worker] def kill() {
+  private[worker] def kill(): Unit = {
+    logInfo("Killing driver process!")
+    killed = true
     synchronized {
-      process.foreach(_.destroy())
-      killed = true
+      process.foreach { p =>
+        val exitCode = Utils.terminateProcess(p, DRIVER_TERMINATE_TIMEOUT_MS)
+        if (exitCode.isEmpty) {
+          logWarning("Failed to terminate driver process: " + p +
+              ". This process will likely be orphaned.")
+        }
+      }
     }
   }
 
@@ -142,7 +148,6 @@ private[deploy] class DriverRunner(
    */
   private def downloadUserJar(driverDir: File): String = {
     val jarPath = new Path(driverDesc.jarUrl)
-
     val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
     val destPath = new File(driverDir.getAbsolutePath, jarPath.getName)
     val jarFileName = jarPath.getName
@@ -168,7 +173,24 @@ private[deploy] class DriverRunner(
     localJarFilename
   }
 
-  private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: 
Boolean) {
+  private[worker] def prepareAndRunDriver(): Int = {
+    val driverDir = createWorkingDirectory()
+    val localJarFilename = downloadUserJar(driverDir)
+
+    def substituteVariables(argument: String): String = argument match {
+      case "{{WORKER_URL}}" => workerUrl
+      case "{{USER_JAR}}" => localJarFilename
+      case other => other
+    }
+
+    // TODO: If we add ability to submit multiple jars they should also be 
added here
+    val builder = CommandUtils.buildProcessBuilder(driverDesc.command, 
securityManager,
+      driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)
+
+    runDriver(builder, driverDir, driverDesc.supervise)
+  }
+
+  private def runDriver(builder: ProcessBuilder, baseDir: File, supervise: 
Boolean): Int = {
     builder.directory(baseDir)
     def initialize(process: Process): Unit = {
       // Redirect stdout and stderr to files
@@ -184,39 +206,40 @@ private[deploy] class DriverRunner(
     runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
   }
 
-  def runCommandWithRetry(
-      command: ProcessBuilderLike, initialize: Process => Unit, supervise: 
Boolean): Unit = {
+  private[worker] def runCommandWithRetry(
+      command: ProcessBuilderLike, initialize: Process => Unit, supervise: 
Boolean): Int = {
+    var exitCode = -1
     // Time to wait between submission retries.
     var waitSeconds = 1
     // A run of this many seconds resets the exponential back-off.
     val successfulRunDuration = 5
-
     var keepTrying = !killed
 
     while (keepTrying) {
       logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", 
"\""))
 
       synchronized {
-        if (killed) { return }
+        if (killed) { return exitCode }
         process = Some(command.start())
         initialize(process.get)
       }
 
       val processStart = clock.getTimeMillis()
-      val exitCode = process.get.waitFor()
-      if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) 
{
-        waitSeconds = 1
-      }
+      exitCode = process.get.waitFor()
 
-      if (supervise && exitCode != 0 && !killed) {
+      // check if attempting another run
+      keepTrying = supervise && exitCode != 0 && !killed
+      if (keepTrying) {
+        if (clock.getTimeMillis() - processStart > successfulRunDuration * 
1000) {
+          waitSeconds = 1
+        }
         logInfo(s"Command exited with status $exitCode, re-launching after 
$waitSeconds s.")
         sleeper.sleep(waitSeconds)
         waitSeconds = waitSeconds * 2 // exponential back-off
       }
-
-      keepTrying = supervise && exitCode != 0 && !killed
-      finalExitCode = Some(exitCode)
     }
+
+    exitCode
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c9a386c/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala 
b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
index 2a1696b..5295604 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
@@ -19,13 +19,18 @@ package org.apache.spark.deploy.worker
 
 import java.io.File
 
+import scala.concurrent.duration._
+
 import org.mockito.Matchers._
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
 
 import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
 import org.apache.spark.deploy.{Command, DriverDescription}
+import org.apache.spark.deploy.master.DriverState
+import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.util.Clock
 
 class DriverRunnerTest extends SparkFunSuite {
@@ -33,8 +38,10 @@ class DriverRunnerTest extends SparkFunSuite {
     val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq())
     val driverDescription = new DriverDescription("jarUrl", 512, 1, true, 
command)
     val conf = new SparkConf()
-    new DriverRunner(conf, "driverId", new File("workDir"), new 
File("sparkHome"),
-      driverDescription, null, "spark://1.2.3.4/worker/", new 
SecurityManager(conf))
+    val worker = mock(classOf[RpcEndpointRef])
+    doNothing().when(worker).send(any())
+    spy(new DriverRunner(conf, "driverId", new File("workDir"), new 
File("sparkHome"),
+      driverDescription, worker, "spark://1.2.3.4/worker/", new 
SecurityManager(conf)))
   }
 
   private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) 
= {
@@ -45,6 +52,19 @@ class DriverRunnerTest extends SparkFunSuite {
     (processBuilder, process)
   }
 
+  private def createTestableDriverRunner(
+      processBuilder: ProcessBuilderLike,
+      superviseRetry: Boolean) = {
+    val runner = createDriverRunner()
+    runner.setSleeper(mock(classOf[Sleeper]))
+    doAnswer(new Answer[Int] {
+      def answer(invocation: InvocationOnMock): Int = {
+        runner.runCommandWithRetry(processBuilder, p => (), supervise = 
superviseRetry)
+      }
+    }).when(runner).prepareAndRunDriver()
+    runner
+  }
+
   test("Process succeeds instantly") {
     val runner = createDriverRunner()
 
@@ -145,4 +165,53 @@ class DriverRunnerTest extends SparkFunSuite {
     verify(sleeper, times(2)).sleep(2)
   }
 
+  test("Kill process finalized with state KILLED") {
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+    val runner = createTestableDriverRunner(processBuilder, superviseRetry = 
true)
+
+    when(process.waitFor()).thenAnswer(new Answer[Int] {
+      def answer(invocation: InvocationOnMock): Int = {
+        runner.kill()
+        -1
+      }
+    })
+
+    runner.start()
+
+    eventually(timeout(10.seconds), interval(100.millis)) {
+      assert(runner.finalState.get === DriverState.KILLED)
+    }
+    verify(process, times(1)).waitFor()
+  }
+
+  test("Finalized with state FINISHED") {
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+    val runner = createTestableDriverRunner(processBuilder, superviseRetry = 
true)
+    when(process.waitFor()).thenReturn(0)
+    runner.start()
+    eventually(timeout(10.seconds), interval(100.millis)) {
+      assert(runner.finalState.get === DriverState.FINISHED)
+    }
+  }
+
+  test("Finalized with state FAILED") {
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+    val runner = createTestableDriverRunner(processBuilder, superviseRetry = 
false)
+    when(process.waitFor()).thenReturn(-1)
+    runner.start()
+    eventually(timeout(10.seconds), interval(100.millis)) {
+      assert(runner.finalState.get === DriverState.FAILED)
+    }
+  }
+
+  test("Handle exception starting process") {
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+    val runner = createTestableDriverRunner(processBuilder, superviseRetry = 
false)
+    when(processBuilder.start()).thenThrow(new NullPointerException("bad 
command list"))
+    runner.start()
+    eventually(timeout(10.seconds), interval(100.millis)) {
+      assert(runner.finalState.get === DriverState.ERROR)
+      assert(runner.finalException.get.isInstanceOf[RuntimeException])
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to