Repository: spark
Updated Branches:
  refs/heads/master 47c73d410 -> 1d5663e92


[SPARK-5760][SPARK-5761] Fix standalone REST protocol corner cases + revamp tests

The changes are summarized in the squashed commit messages listed below. Test or test-related code accounts for 90% of the lines changed.

Author: Andrew Or <and...@databricks.com>

Closes #4557 from andrewor14/rest-tests and squashes the following commits:

b4dc980 [Andrew Or] Merge branch 'master' of github.com:apache/spark into rest-tests
b55e40f [Andrew Or] Add test for unknown fields
cc96993 [Andrew Or] private[spark] -> private[rest]
578cf45 [Andrew Or] Clean up test code a little
d82d971 [Andrew Or] v1 -> serverVersion
ea48f65 [Andrew Or] Merge branch 'master' of github.com:apache/spark into rest-tests
00999a8 [Andrew Or] Revamp tests + fix a few corner cases


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d5663e9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d5663e9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d5663e9

Branch: refs/heads/master
Commit: 1d5663e92cdaaa3dabfa58fdd7aede7e4fa4ec63
Parents: 47c73d4
Author: Andrew Or <and...@databricks.com>
Authored: Thu Feb 12 14:47:52 2015 -0800
Committer: Andrew Or <and...@databricks.com>
Committed: Thu Feb 12 14:47:52 2015 -0800

----------------------------------------------------------------------
 .../deploy/rest/StandaloneRestClient.scala      |  52 +-
 .../deploy/rest/StandaloneRestServer.scala      | 105 ++-
 .../deploy/rest/StandaloneRestSubmitSuite.scala | 671 ++++++++++++++-----
 3 files changed, 589 insertions(+), 239 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1d5663e9/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala 
b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
index 115aa52..c4be1f1 100644
--- 
a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
+++ 
b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
@@ -19,10 +19,11 @@ package org.apache.spark.deploy.rest
 
 import java.io.{DataOutputStream, FileNotFoundException}
 import java.net.{HttpURLConnection, SocketException, URL}
+import javax.servlet.http.HttpServletResponse
 
 import scala.io.Source
 
-import com.fasterxml.jackson.databind.JsonMappingException
+import com.fasterxml.jackson.core.JsonProcessingException
 import com.google.common.base.Charsets
 
 import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
@@ -155,10 +156,21 @@ private[spark] class StandaloneRestClient extends Logging 
{
   /**
    * Read the response from the server and return it as a validated 
[[SubmitRestProtocolResponse]].
    * If the response represents an error, report the embedded message to the 
user.
+   * Exposed for testing.
    */
-  private def readResponse(connection: HttpURLConnection): 
SubmitRestProtocolResponse = {
+  private[rest] def readResponse(connection: HttpURLConnection): 
SubmitRestProtocolResponse = {
     try {
-      val responseJson = 
Source.fromInputStream(connection.getInputStream).mkString
+      val dataStream =
+        if (connection.getResponseCode == HttpServletResponse.SC_OK) {
+          connection.getInputStream
+        } else {
+          connection.getErrorStream
+        }
+      // If the server threw an exception while writing a response, it will 
not have a body
+      if (dataStream == null) {
+        throw new SubmitRestProtocolException("Server returned empty body")
+      }
+      val responseJson = Source.fromInputStream(dataStream).mkString
       logDebug(s"Response from the server:\n$responseJson")
       val response = SubmitRestProtocolMessage.fromJson(responseJson)
       response.validate()
@@ -177,7 +189,7 @@ private[spark] class StandaloneRestClient extends Logging {
       case unreachable @ (_: FileNotFoundException | _: SocketException) =>
         throw new SubmitRestConnectionException(
           s"Unable to connect to server ${connection.getURL}", unreachable)
-      case malformed @ (_: SubmitRestProtocolException | _: 
JsonMappingException) =>
+      case malformed @ (_: JsonProcessingException | _: 
SubmitRestProtocolException) =>
         throw new SubmitRestProtocolException(
           "Malformed response received from server", malformed)
     }
@@ -284,7 +296,27 @@ private[spark] object StandaloneRestClient {
   val REPORT_DRIVER_STATUS_MAX_TRIES = 10
   val PROTOCOL_VERSION = "v1"
 
-  /** Submit an application, assuming Spark parameters are specified through 
system properties. */
+  /**
+   * Submit an application, assuming Spark parameters are specified through 
the given config.
+   * This is abstracted to its own method for testing purposes.
+   */
+  private[rest] def run(
+      appResource: String,
+      mainClass: String,
+      appArgs: Array[String],
+      conf: SparkConf,
+      env: Map[String, String] = sys.env): SubmitRestProtocolResponse = {
+    val master = conf.getOption("spark.master").getOrElse {
+      throw new IllegalArgumentException("'spark.master' must be set.")
+    }
+    val sparkProperties = conf.getAll.toMap
+    val environmentVariables = env.filter { case (k, _) => 
k.startsWith("SPARK_") }
+    val client = new StandaloneRestClient
+    val submitRequest = client.constructSubmitRequest(
+      appResource, mainClass, appArgs, sparkProperties, environmentVariables)
+    client.createSubmission(master, submitRequest)
+  }
+
   def main(args: Array[String]): Unit = {
     if (args.size < 2) {
       sys.error("Usage: StandaloneRestClient [app resource] [main class] [app 
args*]")
@@ -294,14 +326,6 @@ private[spark] object StandaloneRestClient {
     val mainClass = args(1)
     val appArgs = args.slice(2, args.size)
     val conf = new SparkConf
-    val master = conf.getOption("spark.master").getOrElse {
-      throw new IllegalArgumentException("'spark.master' must be set.")
-    }
-    val sparkProperties = conf.getAll.toMap
-    val environmentVariables = sys.env.filter { case (k, _) => 
k.startsWith("SPARK_") }
-    val client = new StandaloneRestClient
-    val submitRequest = client.constructSubmitRequest(
-      appResource, mainClass, appArgs, sparkProperties, environmentVariables)
-    client.createSubmission(master, submitRequest)
+    run(appResource, mainClass, appArgs, conf)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d5663e9/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala 
b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index acd3a2b..f9e0478 100644
--- 
a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ 
b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.deploy.rest
 
-import java.io.{DataOutputStream, File}
+import java.io.File
 import java.net.InetSocketAddress
 import javax.servlet.http.{HttpServlet, HttpServletRequest, 
HttpServletResponse}
 
 import scala.io.Source
 
 import akka.actor.ActorRef
-import com.fasterxml.jackson.databind.JsonMappingException
-import com.google.common.base.Charsets
+import com.fasterxml.jackson.core.JsonProcessingException
 import org.eclipse.jetty.server.Server
 import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
 import org.eclipse.jetty.util.thread.QueuedThreadPool
@@ -70,14 +69,14 @@ private[spark] class StandaloneRestServer(
   import StandaloneRestServer._
 
   private var _server: Option[Server] = None
-  private val baseContext = s"/$PROTOCOL_VERSION/submissions"
-
-  // A mapping from servlets to the URL prefixes they are responsible for
-  private val servletToContext = Map[StandaloneRestServlet, String](
-    new SubmitRequestServlet(masterActor, masterUrl, masterConf) -> 
s"$baseContext/create/*",
-    new KillRequestServlet(masterActor, masterConf) -> s"$baseContext/kill/*",
-    new StatusRequestServlet(masterActor, masterConf) -> 
s"$baseContext/status/*",
-    new ErrorServlet -> "/*" // default handler
+
+  // A mapping from URL prefixes to servlets that serve them. Exposed for 
testing.
+  protected val baseContext = s"/$PROTOCOL_VERSION/submissions"
+  protected val contextToServlet = Map[String, StandaloneRestServlet](
+    s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, 
masterUrl, masterConf),
+    s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf),
+    s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, 
masterConf),
+    "/*" -> new ErrorServlet // default handler
   )
 
   /** Start the server and return the bound port. */
@@ -99,7 +98,7 @@ private[spark] class StandaloneRestServer(
     server.setThreadPool(threadPool)
     val mainHandler = new ServletContextHandler
     mainHandler.setContextPath("/")
-    servletToContext.foreach { case (servlet, prefix) =>
+    contextToServlet.foreach { case (prefix, servlet) =>
       mainHandler.addServlet(new ServletHolder(servlet), prefix)
     }
     server.setHandler(mainHandler)
@@ -113,7 +112,7 @@ private[spark] class StandaloneRestServer(
   }
 }
 
-private object StandaloneRestServer {
+private[rest] object StandaloneRestServer {
   val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION
   val SC_UNKNOWN_PROTOCOL_VERSION = 468
 }
@@ -121,20 +120,7 @@ private object StandaloneRestServer {
 /**
  * An abstract servlet for handling requests passed to the 
[[StandaloneRestServer]].
  */
-private abstract class StandaloneRestServlet extends HttpServlet with Logging {
-
-  /** Service a request. If an exception is thrown in the process, indicate 
server error. */
-  protected override def service(
-      request: HttpServletRequest,
-      response: HttpServletResponse): Unit = {
-    try {
-      super.service(request, response)
-    } catch {
-      case e: Exception =>
-        logError("Exception while handling request", e)
-        response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
-    }
-  }
+private[rest] abstract class StandaloneRestServlet extends HttpServlet with 
Logging {
 
   /**
    * Serialize the given response message to JSON and send it through the 
response servlet.
@@ -146,11 +132,7 @@ private abstract class StandaloneRestServlet extends 
HttpServlet with Logging {
     val message = validateResponse(responseMessage, responseServlet)
     responseServlet.setContentType("application/json")
     responseServlet.setCharacterEncoding("utf-8")
-    responseServlet.setStatus(HttpServletResponse.SC_OK)
-    val content = message.toJson.getBytes(Charsets.UTF_8)
-    val out = new DataOutputStream(responseServlet.getOutputStream)
-    out.write(content)
-    out.close()
+    responseServlet.getWriter.write(message.toJson)
   }
 
   /**
@@ -187,6 +169,19 @@ private abstract class StandaloneRestServlet extends 
HttpServlet with Logging {
   }
 
   /**
+   * Parse a submission ID from the relative path, assuming it is the first 
part of the path.
+   * For instance, we expect the path to take the form /[submission 
ID]/maybe/something/else.
+   * The returned submission ID cannot be empty. If the path is unexpected, 
return None.
+   */
+  protected def parseSubmissionId(path: String): Option[String] = {
+    if (path == null || path.isEmpty) {
+      None
+    } else {
+      path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty)
+    }
+  }
+
+  /**
    * Validate the response to ensure that it is correctly constructed.
    *
    * If it is, simply return the message as is. Otherwise, return an error 
response instead
@@ -209,7 +204,7 @@ private abstract class StandaloneRestServlet extends 
HttpServlet with Logging {
 /**
  * A servlet for handling kill requests passed to the [[StandaloneRestServer]].
  */
-private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
+private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
   extends StandaloneRestServlet {
 
   /**
@@ -219,18 +214,15 @@ private class KillRequestServlet(masterActor: ActorRef, 
conf: SparkConf)
   protected override def doPost(
       request: HttpServletRequest,
       response: HttpServletResponse): Unit = {
-    val submissionId = request.getPathInfo.stripPrefix("/")
-    val responseMessage =
-      if (submissionId.nonEmpty) {
-        handleKill(submissionId)
-      } else {
-        response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
-        handleError("Submission ID is missing in kill request.")
-      }
+    val submissionId = parseSubmissionId(request.getPathInfo)
+    val responseMessage = submissionId.map(handleKill).getOrElse {
+      response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+      handleError("Submission ID is missing in kill request.")
+    }
     sendResponse(responseMessage, response)
   }
 
-  private def handleKill(submissionId: String): KillSubmissionResponse = {
+  protected def handleKill(submissionId: String): KillSubmissionResponse = {
     val askTimeout = AkkaUtils.askTimeout(conf)
     val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
       DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
@@ -246,7 +238,7 @@ private class KillRequestServlet(masterActor: ActorRef, 
conf: SparkConf)
 /**
  * A servlet for handling status requests passed to the 
[[StandaloneRestServer]].
  */
-private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
+private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: 
SparkConf)
   extends StandaloneRestServlet {
 
   /**
@@ -256,18 +248,15 @@ private class StatusRequestServlet(masterActor: ActorRef, 
conf: SparkConf)
   protected override def doGet(
       request: HttpServletRequest,
       response: HttpServletResponse): Unit = {
-    val submissionId = request.getPathInfo.stripPrefix("/")
-    val responseMessage =
-      if (submissionId.nonEmpty) {
-        handleStatus(submissionId)
-      } else {
-        response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
-        handleError("Submission ID is missing in status request.")
-      }
+    val submissionId = parseSubmissionId(request.getPathInfo)
+    val responseMessage = submissionId.map(handleStatus).getOrElse {
+      response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+      handleError("Submission ID is missing in status request.")
+    }
     sendResponse(responseMessage, response)
   }
 
-  private def handleStatus(submissionId: String): SubmissionStatusResponse = {
+  protected def handleStatus(submissionId: String): SubmissionStatusResponse = 
{
     val askTimeout = AkkaUtils.askTimeout(conf)
     val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
       DeployMessages.RequestDriverStatus(submissionId), masterActor, 
askTimeout)
@@ -287,7 +276,7 @@ private class StatusRequestServlet(masterActor: ActorRef, 
conf: SparkConf)
 /**
  * A servlet for handling submit requests passed to the 
[[StandaloneRestServer]].
  */
-private class SubmitRequestServlet(
+private[rest] class SubmitRequestServlet(
     masterActor: ActorRef,
     masterUrl: String,
     conf: SparkConf)
@@ -313,7 +302,7 @@ private class SubmitRequestServlet(
         handleSubmit(requestMessageJson, requestMessage, responseServlet)
       } catch {
         // The client failed to provide a valid JSON, so this is not our fault
-        case e @ (_: JsonMappingException | _: SubmitRestProtocolException) =>
+        case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) 
=>
           responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
           handleError("Malformed request: " + formatException(e))
       }
@@ -413,7 +402,7 @@ private class ErrorServlet extends StandaloneRestServlet {
       request: HttpServletRequest,
       response: HttpServletResponse): Unit = {
     val path = request.getPathInfo
-    val parts = path.stripPrefix("/").split("/").toSeq
+    val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList
     var versionMismatch = false
     var msg =
       parts match {
@@ -423,10 +412,10 @@ private class ErrorServlet extends StandaloneRestServlet {
         case `serverVersion` :: Nil =>
           // http://host:port/correct-version
           "Missing the /submissions prefix."
-        case `serverVersion` :: "submissions" :: Nil =>
-          // http://host:port/correct-version/submissions
+        case `serverVersion` :: "submissions" :: tail =>
+          // http://host:port/correct-version/submissions/*
           "Missing an action: please specify one of /create, /kill, or 
/status."
-        case unknownVersion :: _ =>
+        case unknownVersion :: tail =>
           // http://host:port/unknown-version/*
           versionMismatch = true
           s"Unknown protocol version '$unknownVersion'."

http://git-wip-us.apache.org/repos/asf/spark/blob/1d5663e9/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
 
b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index 29aed89..a345e06 100644
--- 
a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -17,141 +17,412 @@
 
 package org.apache.spark.deploy.rest
 
-import java.io.{File, FileInputStream, FileOutputStream, PrintWriter}
-import java.util.jar.{JarEntry, JarOutputStream}
-import java.util.zip.ZipEntry
+import java.io.DataOutputStream
+import java.net.{HttpURLConnection, URL}
+import javax.servlet.http.HttpServletResponse
 
-import scala.collection.mutable.ArrayBuffer
-import scala.io.Source
+import scala.collection.mutable
 
-import akka.actor.ActorSystem
-import com.google.common.io.ByteStreams
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
-import org.scalatest.exceptions.TestFailedException
+import akka.actor.{Actor, ActorRef, ActorSystem, Props}
+import com.google.common.base.Charsets
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+import org.json4s.JsonAST._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.deploy.DeployMessages._
 import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
-import org.apache.spark.deploy.master.{DriverState, Master}
-import org.apache.spark.deploy.worker.Worker
+import org.apache.spark.deploy.master.DriverState._
 
 /**
- * End-to-end tests for the REST application submission protocol in standalone 
mode.
+ * Tests for the REST application submission protocol used in standalone 
cluster mode.
  */
-class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with 
BeforeAndAfterEach {
-  private val systemsToStop = new ArrayBuffer[ActorSystem]
-  private val masterRestUrl = startLocalCluster()
+class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach {
   private val client = new StandaloneRestClient
-  private val mainJar = StandaloneRestSubmitSuite.createJar()
-  private val mainClass = StandaloneRestApp.getClass.getName.stripSuffix("$")
+  private var actorSystem: Option[ActorSystem] = None
+  private var server: Option[StandaloneRestServer] = None
 
-  override def afterAll() {
-    systemsToStop.foreach(_.shutdown())
+  override def afterEach() {
+    actorSystem.foreach(_.shutdown())
+    server.foreach(_.stop())
   }
 
-  test("simple submit until completion") {
-    val resultsFile = File.createTempFile("test-submit", ".txt")
-    val numbers = Seq(1, 2, 3)
-    val size = 500
-    val submissionId = submitApplication(resultsFile, numbers, size)
-    waitUntilFinished(submissionId)
-    validateResult(resultsFile, numbers, size)
+  test("construct submit request") {
+    val appArgs = Array("one", "two", "three")
+    val sparkProperties = Map("spark.app.name" -> "pi")
+    val environmentVariables = Map("SPARK_ONE" -> "UN", "SPARK_TWO" -> "DEUX")
+    val request = client.constructSubmitRequest(
+      "my-app-resource", "my-main-class", appArgs, sparkProperties, 
environmentVariables)
+    assert(request.action === Utils.getFormattedClassName(request))
+    assert(request.clientSparkVersion === SPARK_VERSION)
+    assert(request.appResource === "my-app-resource")
+    assert(request.mainClass === "my-main-class")
+    assert(request.appArgs === appArgs)
+    assert(request.sparkProperties === sparkProperties)
+    assert(request.environmentVariables === environmentVariables)
   }
 
-  test("kill empty submission") {
-    val response = client.killSubmission(masterRestUrl, 
"submission-that-does-not-exist")
-    val killResponse = getKillResponse(response)
-    val killSuccess = killResponse.success
-    assert(!killSuccess)
+  test("create submission") {
+    val submittedDriverId = "my-driver-id"
+    val submitMessage = "your driver is submitted"
+    val masterUrl = startDummyServer(submitId = submittedDriverId, 
submitMessage = submitMessage)
+    val appArgs = Array("one", "two", "four")
+    val request = constructSubmitRequest(masterUrl, appArgs)
+    assert(request.appArgs === appArgs)
+    assert(request.sparkProperties("spark.master") === masterUrl)
+    val response = client.createSubmission(masterUrl, request)
+    val submitResponse = getSubmitResponse(response)
+    assert(submitResponse.action === 
Utils.getFormattedClassName(submitResponse))
+    assert(submitResponse.serverSparkVersion === SPARK_VERSION)
+    assert(submitResponse.message === submitMessage)
+    assert(submitResponse.submissionId === submittedDriverId)
+    assert(submitResponse.success)
+  }
+
+  test("create submission from main method") {
+    val submittedDriverId = "your-driver-id"
+    val submitMessage = "my driver is submitted"
+    val masterUrl = startDummyServer(submitId = submittedDriverId, 
submitMessage = submitMessage)
+    val conf = new SparkConf(loadDefaults = false)
+    conf.set("spark.master", masterUrl)
+    conf.set("spark.app.name", "dreamer")
+    val appArgs = Array("one", "two", "six")
+    // main method calls this
+    val response = StandaloneRestClient.run("app-resource", "main-class", 
appArgs, conf)
+    val submitResponse = getSubmitResponse(response)
+    assert(submitResponse.action === 
Utils.getFormattedClassName(submitResponse))
+    assert(submitResponse.serverSparkVersion === SPARK_VERSION)
+    assert(submitResponse.message === submitMessage)
+    assert(submitResponse.submissionId === submittedDriverId)
+    assert(submitResponse.success)
   }
 
-  test("kill running submission") {
-    val resultsFile = File.createTempFile("test-kill", ".txt")
-    val numbers = Seq(1, 2, 3)
-    val size = 500
-    val submissionId = submitApplication(resultsFile, numbers, size)
-    val response = client.killSubmission(masterRestUrl, submissionId)
+  test("kill submission") {
+    val submissionId = "my-lyft-driver"
+    val killMessage = "your driver is killed"
+    val masterUrl = startDummyServer(killMessage = killMessage)
+    val response = client.killSubmission(masterUrl, submissionId)
     val killResponse = getKillResponse(response)
-    val killSuccess = killResponse.success
-    waitUntilFinished(submissionId)
-    val response2 = client.requestSubmissionStatus(masterRestUrl, submissionId)
-    val statusResponse = getStatusResponse(response2)
-    val statusSuccess = statusResponse.success
-    val driverState = statusResponse.driverState
-    assert(killSuccess)
-    assert(statusSuccess)
-    assert(driverState === DriverState.KILLED.toString)
-    // we should not see the expected results because we killed the submission
-    intercept[TestFailedException] { validateResult(resultsFile, numbers, 
size) }
+    assert(killResponse.action === Utils.getFormattedClassName(killResponse))
+    assert(killResponse.serverSparkVersion === SPARK_VERSION)
+    assert(killResponse.message === killMessage)
+    assert(killResponse.submissionId === submissionId)
+    assert(killResponse.success)
   }
 
-  test("request status for empty submission") {
-    val response = client.requestSubmissionStatus(masterRestUrl, 
"submission-that-does-not-exist")
+  test("request submission status") {
+    val submissionId = "my-uber-driver"
+    val submissionState = KILLED
+    val submissionException = new Exception("there was an irresponsible mix of 
alcohol and cars")
+    val masterUrl = startDummyServer(state = submissionState, exception = 
Some(submissionException))
+    val response = client.requestSubmissionStatus(masterUrl, submissionId)
     val statusResponse = getStatusResponse(response)
-    val statusSuccess = statusResponse.success
-    assert(!statusSuccess)
+    assert(statusResponse.action === 
Utils.getFormattedClassName(statusResponse))
+    assert(statusResponse.serverSparkVersion === SPARK_VERSION)
+    assert(statusResponse.message.contains(submissionException.getMessage))
+    assert(statusResponse.submissionId === submissionId)
+    assert(statusResponse.driverState === submissionState.toString)
+    assert(statusResponse.success)
+  }
+
+  test("create then kill") {
+    val masterUrl = startSmartServer()
+    val request = constructSubmitRequest(masterUrl)
+    val response1 = client.createSubmission(masterUrl, request)
+    val submitResponse = getSubmitResponse(response1)
+    assert(submitResponse.success)
+    assert(submitResponse.submissionId != null)
+    // kill submission that was just created
+    val submissionId = submitResponse.submissionId
+    val response2 = client.killSubmission(masterUrl, submissionId)
+    val killResponse = getKillResponse(response2)
+    assert(killResponse.success)
+    assert(killResponse.submissionId === submissionId)
+  }
+
+  test("create then request status") {
+    val masterUrl = startSmartServer()
+    val request = constructSubmitRequest(masterUrl)
+    val response1 = client.createSubmission(masterUrl, request)
+    val submitResponse = getSubmitResponse(response1)
+    assert(submitResponse.success)
+    assert(submitResponse.submissionId != null)
+    // request status of submission that was just created
+    val submissionId = submitResponse.submissionId
+    val response2 = client.requestSubmissionStatus(masterUrl, submissionId)
+    val statusResponse = getStatusResponse(response2)
+    assert(statusResponse.success)
+    assert(statusResponse.submissionId === submissionId)
+    assert(statusResponse.driverState === RUNNING.toString)
+  }
+
+  test("create then kill then request status") {
+    val masterUrl = startSmartServer()
+    val request = constructSubmitRequest(masterUrl)
+    val response1 = client.createSubmission(masterUrl, request)
+    val response2 = client.createSubmission(masterUrl, request)
+    val submitResponse1 = getSubmitResponse(response1)
+    val submitResponse2 = getSubmitResponse(response2)
+    assert(submitResponse1.success)
+    assert(submitResponse2.success)
+    assert(submitResponse1.submissionId != null)
+    assert(submitResponse2.submissionId != null)
+    val submissionId1 = submitResponse1.submissionId
+    val submissionId2 = submitResponse2.submissionId
+    // kill only submission 1, but not submission 2
+    val response3 = client.killSubmission(masterUrl, submissionId1)
+    val killResponse = getKillResponse(response3)
+    assert(killResponse.success)
+    assert(killResponse.submissionId === submissionId1)
+    // request status for both submissions: 1 should be KILLED but 2 should be 
RUNNING still
+    val response4 = client.requestSubmissionStatus(masterUrl, submissionId1)
+    val response5 = client.requestSubmissionStatus(masterUrl, submissionId2)
+    val statusResponse1 = getStatusResponse(response4)
+    val statusResponse2 = getStatusResponse(response5)
+    assert(statusResponse1.submissionId === submissionId1)
+    assert(statusResponse2.submissionId === submissionId2)
+    assert(statusResponse1.driverState === KILLED.toString)
+    assert(statusResponse2.driverState === RUNNING.toString)
+  }
+
+  test("kill or request status before create") {
+    val masterUrl = startSmartServer()
+    val doesNotExist = "does-not-exist"
+    // kill a non-existent submission
+    val response1 = client.killSubmission(masterUrl, doesNotExist)
+    val killResponse = getKillResponse(response1)
+    assert(!killResponse.success)
+    assert(killResponse.submissionId === doesNotExist)
+    // request status for a non-existent submission
+    val response2 = client.requestSubmissionStatus(masterUrl, doesNotExist)
+    val statusResponse = getStatusResponse(response2)
+    assert(!statusResponse.success)
+    assert(statusResponse.submissionId === doesNotExist)
+  }
+
+  /* ---------------------------------------- *
+   |     Aberrant client / server behavior    |
+   * ---------------------------------------- */
+
+  test("good request paths") {
+    val masterUrl = startSmartServer()
+    val httpUrl = masterUrl.replace("spark://", "http://";)
+    val v = StandaloneRestServer.PROTOCOL_VERSION
+    val json = constructSubmitRequest(masterUrl).toJson
+    val submitRequestPath = s"$httpUrl/$v/submissions/create"
+    val killRequestPath = s"$httpUrl/$v/submissions/kill"
+    val statusRequestPath = s"$httpUrl/$v/submissions/status"
+    val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, 
"POST", json)
+    val (response2, code2) = 
sendHttpRequestWithResponse(s"$killRequestPath/anything", "POST")
+    val (response3, code3) = 
sendHttpRequestWithResponse(s"$killRequestPath/any/thing", "POST")
+    val (response4, code4) = 
sendHttpRequestWithResponse(s"$statusRequestPath/anything", "GET")
+    val (response5, code5) = 
sendHttpRequestWithResponse(s"$statusRequestPath/any/thing", "GET")
+    // these should all succeed and the responses should be of the correct 
types
+    getSubmitResponse(response1)
+    val killResponse1 = getKillResponse(response2)
+    val killResponse2 = getKillResponse(response3)
+    val statusResponse1 = getStatusResponse(response4)
+    val statusResponse2 = getStatusResponse(response5)
+    assert(killResponse1.submissionId === "anything")
+    assert(killResponse2.submissionId === "any")
+    assert(statusResponse1.submissionId === "anything")
+    assert(statusResponse2.submissionId === "any")
+    assert(code1 === HttpServletResponse.SC_OK)
+    assert(code2 === HttpServletResponse.SC_OK)
+    assert(code3 === HttpServletResponse.SC_OK)
+    assert(code4 === HttpServletResponse.SC_OK)
+    assert(code5 === HttpServletResponse.SC_OK)
+  }
+
+  test("good request paths, bad requests") {
+    val masterUrl = startSmartServer()
+    val httpUrl = masterUrl.replace("spark://", "http://";)
+    val v = StandaloneRestServer.PROTOCOL_VERSION
+    val submitRequestPath = s"$httpUrl/$v/submissions/create"
+    val killRequestPath = s"$httpUrl/$v/submissions/kill"
+    val statusRequestPath = s"$httpUrl/$v/submissions/status"
+    val goodJson = constructSubmitRequest(masterUrl).toJson
+    val badJson1 = goodJson.replaceAll("action", "fraction") // invalid JSON
+    val badJson2 = goodJson.substring(goodJson.size / 2) // malformed JSON
+    val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, 
"POST") // missing JSON
+    val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, 
"POST", badJson1)
+    val (response3, code3) = sendHttpRequestWithResponse(submitRequestPath, 
"POST", badJson2)
+    val (response4, code4) = sendHttpRequestWithResponse(killRequestPath, 
"POST") // missing ID
+    val (response5, code5) = sendHttpRequestWithResponse(s"$killRequestPath/", 
"POST")
+    val (response6, code6) = sendHttpRequestWithResponse(statusRequestPath, 
"GET") // missing ID
+    val (response7, code7) = 
sendHttpRequestWithResponse(s"$statusRequestPath/", "GET")
+    // these should all fail as error responses
+    getErrorResponse(response1)
+    getErrorResponse(response2)
+    getErrorResponse(response3)
+    getErrorResponse(response4)
+    getErrorResponse(response5)
+    getErrorResponse(response6)
+    getErrorResponse(response7)
+    assert(code1 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code2 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code3 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code4 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code5 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code6 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code7 === HttpServletResponse.SC_BAD_REQUEST)
+  }
+
+  test("bad request paths") {
+    val masterUrl = startSmartServer()
+    val httpUrl = masterUrl.replace("spark://", "http://";)
+    val v = StandaloneRestServer.PROTOCOL_VERSION
+    val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET")
+    val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET")
+    val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET")
+    val (response4, code4) = sendHttpRequestWithResponse(s"$httpUrl/$v/", 
"GET")
+    val (response5, code5) = 
sendHttpRequestWithResponse(s"$httpUrl/$v/submissions", "GET")
+    val (response6, code6) = 
sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/", "GET")
+    val (response7, code7) = 
sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/bad", "GET")
+    val (response8, code8) = 
sendHttpRequestWithResponse(s"$httpUrl/bad-version", "GET")
+    assert(code1 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code2 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code3 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code4 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code5 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code6 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code7 === HttpServletResponse.SC_BAD_REQUEST)
+    assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION)
+    // all responses should be error responses
+    val errorResponse1 = getErrorResponse(response1)
+    val errorResponse2 = getErrorResponse(response2)
+    val errorResponse3 = getErrorResponse(response3)
+    val errorResponse4 = getErrorResponse(response4)
+    val errorResponse5 = getErrorResponse(response5)
+    val errorResponse6 = getErrorResponse(response6)
+    val errorResponse7 = getErrorResponse(response7)
+    val errorResponse8 = getErrorResponse(response8)
+    // only the incompatible version response should have server protocol 
version set
+    assert(errorResponse1.highestProtocolVersion === null)
+    assert(errorResponse2.highestProtocolVersion === null)
+    assert(errorResponse3.highestProtocolVersion === null)
+    assert(errorResponse4.highestProtocolVersion === null)
+    assert(errorResponse5.highestProtocolVersion === null)
+    assert(errorResponse6.highestProtocolVersion === null)
+    assert(errorResponse7.highestProtocolVersion === null)
+    assert(errorResponse8.highestProtocolVersion === 
StandaloneRestServer.PROTOCOL_VERSION)
+  }
+
+  test("server returns unknown fields") {
+    val masterUrl = startSmartServer()
+    val httpUrl = masterUrl.replace("spark://", "http://";)
+    val v = StandaloneRestServer.PROTOCOL_VERSION
+    val submitRequestPath = s"$httpUrl/$v/submissions/create"
+    val oldJson = constructSubmitRequest(masterUrl).toJson
+    val oldFields = parse(oldJson).asInstanceOf[JObject].obj
+    val newFields = oldFields ++ Seq(
+      JField("tomato", JString("not-a-fruit")),
+      JField("potato", JString("not-po-tah-to"))
+    )
+    val newJson = pretty(render(JObject(newFields)))
+    // send two requests, one with the unknown fields and the other without
+    val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, 
"POST", oldJson)
+    val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, 
"POST", newJson)
+    val submitResponse1 = getSubmitResponse(response1)
+    val submitResponse2 = getSubmitResponse(response2)
+    assert(code1 === HttpServletResponse.SC_OK)
+    assert(code2 === HttpServletResponse.SC_OK)
+    // only the response to the modified request should have unknown fields set
+    assert(submitResponse1.unknownFields === null)
+    assert(submitResponse2.unknownFields === Array("tomato", "potato"))
+  }
+
+  test("client handles faulty server") {
+    val masterUrl = startFaultyServer()
+    val httpUrl = masterUrl.replace("spark://", "http://";)
+    val v = StandaloneRestServer.PROTOCOL_VERSION
+    val submitRequestPath = s"$httpUrl/$v/submissions/create"
+    val killRequestPath = s"$httpUrl/$v/submissions/kill/anything"
+    val statusRequestPath = s"$httpUrl/$v/submissions/status/anything"
+    val json = constructSubmitRequest(masterUrl).toJson
+    // server returns malformed response unwittingly
+    // client should throw an appropriate exception to indicate server failure
+    val conn1 = sendHttpRequest(submitRequestPath, "POST", json)
+    intercept[SubmitRestProtocolException] { client.readResponse(conn1) }
+    // server attempts to send invalid response, but fails internally on 
validation
+    // client should receive an error response as server is able to recover
+    val conn2 = sendHttpRequest(killRequestPath, "POST")
+    val response2 = client.readResponse(conn2)
+    getErrorResponse(response2)
+    assert(conn2.getResponseCode === 
HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+    // server explodes internally beyond recovery
+    // client should throw an appropriate exception to indicate server failure
+    val conn3 = sendHttpRequest(statusRequestPath, "GET")
+    intercept[SubmitRestProtocolException] { client.readResponse(conn3) } // 
empty response
+    assert(conn3.getResponseCode === 
HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+  }
+
+  /* --------------------- *
+   |     Helper methods    |
+   * --------------------- */
+
+  /** Start a dummy server that responds to requests using the specified 
parameters. */
+  private def startDummyServer(
+      submitId: String = "fake-driver-id",
+      submitMessage: String = "driver is submitted",
+      killMessage: String = "driver is killed",
+      state: DriverState = FINISHED,
+      exception: Option[Exception] = None): String = {
+    startServer(new DummyMaster(submitId, submitMessage, killMessage, state, 
exception))
+  }
+
+  /** Start a smarter dummy server that keeps track of submitted driver 
states. */
+  private def startSmartServer(): String = {
+    startServer(new SmarterMaster)
+  }
+
+  /** Start a dummy server that is faulty in many ways... */
+  private def startFaultyServer(): String = {
+    startServer(new DummyMaster, faulty = true)
   }
 
   /**
-   * Start a local cluster containing one Master and a few Workers.
-   * Do not use [[org.apache.spark.deploy.LocalSparkCluster]] here because we 
want the REST URL.
-   * Return the Master's REST URL to which applications should be submitted.
+   * Start a [[StandaloneRestServer]] that communicates with the given actor.
+   * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead.
+   * Return the master URL that corresponds to the address of this server.
    */
-  private def startLocalCluster(): String = {
-    val conf = new SparkConf(false)
-      .set("spark.master.rest.enabled", "true")
-      .set("spark.master.rest.port", "0")
-    val (numWorkers, coresPerWorker, memPerWorker) = (2, 1, 512)
-    val localHostName = Utils.localHostName()
-    val (masterSystem, masterPort, _, _masterRestPort) =
-      Master.startSystemAndActor(localHostName, 0, 0, conf)
-    val masterRestPort = _masterRestPort.getOrElse { fail("REST server not 
started on Master!") }
-    val masterUrl = "spark://" + localHostName + ":" + masterPort
-    val masterRestUrl = "spark://" + localHostName + ":" + masterRestPort
-    (1 to numWorkers).foreach { n =>
-      val (workerSystem, _) = Worker.startSystemAndActor(
-        localHostName, 0, 0, coresPerWorker, memPerWorker, Array(masterUrl), 
null, Some(n))
-      systemsToStop.append(workerSystem)
-    }
-    systemsToStop.append(masterSystem)
-    masterRestUrl
+  private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): 
String = {
+    val name = "test-standalone-rest-protocol"
+    val conf = new SparkConf
+    val localhost = Utils.localHostName()
+    val securityManager = new SecurityManager(conf)
+    val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, 
conf, securityManager)
+    val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster))
+    val _server =
+      if (faulty) {
+        new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, 
"spark://fake:7077", conf)
+      } else {
+        new StandaloneRestServer(localhost, 0, fakeMasterRef, 
"spark://fake:7077", conf)
+      }
+    val port = _server.start()
+    // set these to clean them up after every test
+    actorSystem = Some(_actorSystem)
+    server = Some(_server)
+    s"spark://$localhost:$port"
   }
 
-  /** Submit the [[StandaloneRestApp]] and return the corresponding submission 
ID. */
-  private def submitApplication(resultsFile: File, numbers: Seq[Int], size: 
Int): String = {
-    val appArgs = Seq(resultsFile.getAbsolutePath) ++ numbers.map(_.toString) 
++ Seq(size.toString)
+  /** Create a submit request with real parameters using Spark submit. */
+  private def constructSubmitRequest(
+      masterUrl: String,
+      appArgs: Array[String] = Array.empty): CreateSubmissionRequest = {
+    val mainClass = "main-class-not-used"
+    val mainJar = "dummy-jar-not-used.jar"
     val commandLineArgs = Array(
       "--deploy-mode", "cluster",
-      "--master", masterRestUrl,
+      "--master", masterUrl,
       "--name", mainClass,
       "--class", mainClass,
       mainJar) ++ appArgs
     val args = new SparkSubmitArguments(commandLineArgs)
     val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args)
-    val request = client.constructSubmitRequest(
-      mainJar, mainClass, appArgs.toArray, sparkProperties.toMap, Map.empty)
-    val response = client.createSubmission(masterRestUrl, request)
-    val submitResponse = getSubmitResponse(response)
-    val submissionId = submitResponse.submissionId
-    assert(submissionId != null, "Application submission was unsuccessful!")
-    submissionId
-  }
-
-  /** Wait until the given submission has finished running up to the specified 
timeout. */
-  private def waitUntilFinished(submissionId: String, maxSeconds: Int = 30): 
Unit = {
-    var finished = false
-    val expireTime = System.currentTimeMillis + maxSeconds * 1000
-    while (!finished) {
-      val response = client.requestSubmissionStatus(masterRestUrl, 
submissionId)
-      val statusResponse = getStatusResponse(response)
-      val driverState = statusResponse.driverState
-      finished =
-        driverState != DriverState.SUBMITTED.toString &&
-        driverState != DriverState.RUNNING.toString
-      if (System.currentTimeMillis > expireTime) {
-        fail(s"Driver $submissionId did not finish within $maxSeconds 
seconds.")
-      }
-    }
+    client.constructSubmitRequest(
+      mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty)
   }
 
   /** Return the response as a submit response, or fail with error otherwise. 
*/
@@ -181,85 +452,151 @@ class StandaloneRestSubmitSuite extends FunSuite with 
BeforeAndAfterAll with Bef
     }
   }
 
-  /** Validate whether the application produced the corrupt output. */
-  private def validateResult(resultsFile: File, numbers: Seq[Int], size: Int): 
Unit = {
-    val lines = Source.fromFile(resultsFile.getAbsolutePath).getLines().toSeq
-    val unexpectedContent =
-      if (lines.nonEmpty) {
-        "[\n" + lines.map { l => "  " + l }.mkString("\n") + "\n]"
-      } else {
-        "[EMPTY]"
-      }
-    assert(lines.size === 2, s"Unexpected content in file: $unexpectedContent")
-    assert(lines(0).toInt === numbers.sum, s"Sum of ${numbers.mkString(",")} 
is incorrect")
-    assert(lines(1).toInt === (size / 2) + 1, "Result of Spark job is 
incorrect")
+  /** Return the response as an error response, or fail if the response was 
not an error. */
+  private def getErrorResponse(response: SubmitRestProtocolResponse): 
ErrorResponse = {
+    response match {
+      case e: ErrorResponse => e
+      case r => fail(s"Expected error response. Actual: ${r.toJson}")
+    }
   }
-}
-
-private object StandaloneRestSubmitSuite {
-  private val pathPrefix = this.getClass.getPackage.getName.replaceAll("\\.", 
"/")
 
   /**
-   * Create a jar that contains all the class files needed for running the 
[[StandaloneRestApp]].
-   * Return the absolute path to that jar.
+   * Send an HTTP request to the given URL using the method and the body 
specified.
+   * Return the connection object.
    */
-  def createJar(): String = {
-    val jarFile = File.createTempFile("test-standalone-rest-protocol", ".jar")
-    val jarFileStream = new FileOutputStream(jarFile)
-    val jarStream = new JarOutputStream(jarFileStream, new 
java.util.jar.Manifest)
-    jarStream.putNextEntry(new ZipEntry(pathPrefix))
-    getClassFiles.foreach { cf =>
-      jarStream.putNextEntry(new JarEntry(pathPrefix + "/" + cf.getName))
-      val in = new FileInputStream(cf)
-      ByteStreams.copy(in, jarStream)
-      in.close()
+  private def sendHttpRequest(
+      url: String,
+      method: String,
+      body: String = ""): HttpURLConnection = {
+    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
+    conn.setRequestMethod(method)
+    if (body.nonEmpty) {
+      conn.setDoOutput(true)
+      val out = new DataOutputStream(conn.getOutputStream)
+      out.write(body.getBytes(Charsets.UTF_8))
+      out.close()
     }
-    jarStream.close()
-    jarFileStream.close()
-    jarFile.getAbsolutePath
+    conn
   }
 
   /**
-   * Return a list of class files compiled for [[StandaloneRestApp]].
-   * This includes all the anonymous classes used in the application.
+   * Send an HTTP request to the given URL using the method and the body 
specified.
+   * Return a 2-tuple of the response message from the server and the response 
code.
    */
-  private def getClassFiles: Seq[File] = {
-    val className = Utils.getFormattedClassName(StandaloneRestApp)
-    val clazz = StandaloneRestApp.getClass
-    val basePath = 
clazz.getProtectionDomain.getCodeSource.getLocation.toURI.getPath
-    val baseDir = new File(basePath + "/" + pathPrefix)
-    baseDir.listFiles().filter(_.getName.contains(className))
+  private def sendHttpRequestWithResponse(
+      url: String,
+      method: String,
+      body: String = ""): (SubmitRestProtocolResponse, Int) = {
+    val conn = sendHttpRequest(url, method, body)
+    (client.readResponse(conn), conn.getResponseCode)
   }
 }
 
 /**
- * Sample application to be submitted to the cluster using the REST gateway.
- * All relevant classes will be packaged into a jar at run time.
+ * A mock standalone Master that responds with dummy messages.
+ * In all responses, the success parameter is always true.
  */
-object StandaloneRestApp {
-  // Usage: [path to results file] [num1] [num2] [num3] [rddSize]
-  // The first line of the results file should be (num1 + num2 + num3)
-  // The second line should be (rddSize / 2) + 1
-  def main(args: Array[String]) {
-    assert(args.size == 5, s"Expected exactly 5 arguments: 
${args.mkString(",")}")
-    val resultFile = new File(args(0))
-    val writer = new PrintWriter(resultFile)
-    try {
-      val conf = new SparkConf()
-      val sc = new SparkContext(conf)
-      val firstLine = args(1).toInt + args(2).toInt + args(3).toInt
-      val secondLine = sc.parallelize(1 to args(4).toInt)
-        .map { i => (i / 2, i) }
-        .reduceByKey(_ + _)
-        .count()
-      writer.println(firstLine)
-      writer.println(secondLine)
-    } catch {
-      case e: Exception =>
-        writer.println(e)
-        e.getStackTrace.foreach { l => writer.println("  " + l) }
-    } finally {
-      writer.close()
+private class DummyMaster(
+    submitId: String = "fake-driver-id",
+    submitMessage: String = "submitted",
+    killMessage: String = "killed",
+    state: DriverState = FINISHED,
+    exception: Option[Exception] = None)
+  extends Actor {
+
+  override def receive = {
+    case RequestSubmitDriver(driverDesc) =>
+      sender ! SubmitDriverResponse(success = true, Some(submitId), 
submitMessage)
+    case RequestKillDriver(driverId) =>
+      sender ! KillDriverResponse(driverId, success = true, killMessage)
+    case RequestDriverStatus(driverId) =>
+      sender ! DriverStatusResponse(found = true, Some(state), None, None, 
exception)
+  }
+}
+
+/**
+ * A mock standalone Master that keeps track of drivers that have been 
submitted.
+ *
+ * If a driver is submitted, its state is immediately set to RUNNING.
+ * If an existing driver is killed, its state is immediately set to KILLED.
+ * If an existing driver's status is requested, its state is returned in the 
response.
+ * Submits are always successful while kills and status requests are 
successful only
+ * if the driver was submitted in the past.
+ */
+private class SmarterMaster extends Actor {
+  private var counter: Int = 0
+  private val submittedDrivers = new mutable.HashMap[String, DriverState]
+
+  override def receive = {
+    case RequestSubmitDriver(driverDesc) =>
+      val driverId = s"driver-$counter"
+      submittedDrivers(driverId) = RUNNING
+      counter += 1
+      sender ! SubmitDriverResponse(success = true, Some(driverId), 
"submitted")
+
+    case RequestKillDriver(driverId) =>
+      val success = submittedDrivers.contains(driverId)
+      if (success) {
+        submittedDrivers(driverId) = KILLED
+      }
+      sender ! KillDriverResponse(driverId, success, "killed")
+
+    case RequestDriverStatus(driverId) =>
+      val found = submittedDrivers.contains(driverId)
+      val state = submittedDrivers.get(driverId)
+      sender ! DriverStatusResponse(found, state, None, None, None)
+  }
+}
+
+/**
+ * A [[StandaloneRestServer]] that is faulty in many ways.
+ *
+ * When handling a submit request, the server returns a malformed JSON.
+ * When handling a kill request, the server returns an invalid JSON.
+ * When handling a status request, the server throws an internal exception.
+ * The purpose of this class is to test that client handles these cases 
gracefully.
+ */
+private class FaultyStandaloneRestServer(
+    host: String,
+    requestedPort: Int,
+    masterActor: ActorRef,
+    masterUrl: String,
+    masterConf: SparkConf)
+  extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, 
masterConf) {
+
+  protected override val contextToServlet = Map[String, StandaloneRestServlet](
+    s"$baseContext/create/*" -> new MalformedSubmitServlet,
+    s"$baseContext/kill/*" -> new InvalidKillServlet,
+    s"$baseContext/status/*" -> new ExplodingStatusServlet,
+    "/*" -> new ErrorServlet
+  )
+
+  /** A faulty servlet that produces malformed responses. */
+  class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, 
masterUrl, masterConf) {
+    protected override def sendResponse(
+        responseMessage: SubmitRestProtocolResponse,
+        responseServlet: HttpServletResponse): Unit = {
+      val badJson = responseMessage.toJson.drop(10).dropRight(20)
+      responseServlet.getWriter.write(badJson)
+    }
+  }
+
+  /** A faulty servlet that produces invalid responses. */
+  class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) 
{
+    protected override def handleKill(submissionId: String): 
KillSubmissionResponse = {
+      val k = super.handleKill(submissionId)
+      k.submissionId = null
+      k
+    }
+  }
+
+  /** A faulty status servlet that explodes. */
+  class ExplodingStatusServlet extends StatusRequestServlet(masterActor, 
masterConf) {
+    private def explode: Int = 1 / 0
+    protected override def handleStatus(submissionId: String): 
SubmissionStatusResponse = {
+      val s = super.handleStatus(submissionId)
+      s.workerId = explode.toString
+      s
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to