http://git-wip-us.apache.org/repos/asf/spark/blob/81e5619c/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
----------------------------------------------------------------------
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
new file mode 100644
index 0000000..9c3b18e
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.nio.charset.StandardCharsets
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import com.google.common.io.Files
+import org.apache.commons.lang3.SerializationUtils
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.server.MiniYARNCluster
+import org.scalatest.{BeforeAndAfterAll, Matchers}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark._
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.launcher._
+import org.apache.spark.util.Utils
+
+abstract class BaseYarnClusterSuite
+  extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging {
+
+  // log4j configuration for the YARN containers, so that their output is collected
+  // by YARN instead of trying to overwrite unit-tests.log.
+  protected val LOG4J_CONF = """
+    |log4j.rootCategory=DEBUG, console
+    |log4j.appender.console=org.apache.log4j.ConsoleAppender
+    |log4j.appender.console.target=System.err
+    |log4j.appender.console.layout=org.apache.log4j.PatternLayout
+    |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+    |log4j.logger.org.apache.hadoop=WARN
+    |log4j.logger.org.eclipse.jetty=WARN
+    |log4j.logger.org.mortbay=WARN
+    |log4j.logger.org.spark_project.jetty=WARN
+    """.stripMargin
+
+  private var yarnCluster: MiniYARNCluster = _
+  protected var tempDir: File = _
+  private var fakeSparkJar: File = _
+  protected var hadoopConfDir: File = _
+  private var logConfDir: File = _
+
+  var oldSystemProperties: Properties = null
+
+  def newYarnConfig(): YarnConfiguration
+
+  override def beforeAll() {
+    super.beforeAll()
+    oldSystemProperties = SerializationUtils.clone(System.getProperties)
+
+    tempDir = Utils.createTempDir()
+    logConfDir = new File(tempDir, "log4j")
+    logConfDir.mkdir()
+    System.setProperty("SPARK_YARN_MODE", "true")
+
+    val logConfFile = new File(logConfDir, "log4j.properties")
+    Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8)
+
+    // Disable the disk utilization check to avoid the test hanging when people's disks are
+    // getting full.
+    val yarnConf = newYarnConfig()
+    yarnConf.set("yarn.nodemanager.disk-health-checker.max-disk-utilization-per-disk-percentage",
+      "100.0")
+
+    yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
+    yarnCluster.init(yarnConf)
+    yarnCluster.start()
+
+    // There's a race in MiniYARNCluster in which start() may return before the RM has updated
+    // its address in the configuration. You can see this in the logs by noticing that when
+    // MiniYARNCluster prints the address, it still has port "0" assigned, although later the
+    // test works sometimes:
+    //
+    //    INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0
+    //
+    // That log message prints the contents of the RM_ADDRESS config variable. If you check it
+    // later on, it looks something like this:
+    //
+    //    INFO YarnClusterSuite: RM address in configuration is blah:42631
+    //
+    // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't
+    // done so in a timely manner (defined to be 10 seconds).
+    val config = yarnCluster.getConfig()
+    val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10)
+    while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") {
+      if (System.currentTimeMillis() > deadline) {
+        throw new IllegalStateException("Timed out waiting for RM to come up.")
+      }
+      logDebug("RM address still not set in configuration, waiting...")
+      TimeUnit.MILLISECONDS.sleep(100)
+    }
+
+    logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
+
+    fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
+    hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
+    assert(hadoopConfDir.mkdir())
+    File.createTempFile("token", ".txt", hadoopConfDir)
+  }
+
+  override def afterAll() {
+    try {
+      yarnCluster.stop()
+    } finally {
+      System.setProperties(oldSystemProperties)
+      super.afterAll()
+    }
+  }
+
+  protected def runSpark(
+      clientMode: Boolean,
+      klass: String,
+      appArgs: Seq[String] = Nil,
+      sparkArgs: Seq[(String, String)] = Nil,
+      extraClassPath: Seq[String] = Nil,
+      extraJars: Seq[String] = Nil,
+      extraConf: Map[String, String] = Map(),
+      extraEnv: Map[String, String] = Map()): SparkAppHandle.State = {
+    val deployMode = if (clientMode) "client" else "cluster"
+    val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf)
+    val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv
+
+    val launcher = new SparkLauncher(env.asJava)
+    if (klass.endsWith(".py")) {
+      launcher.setAppResource(klass)
+    } else {
+      launcher.setMainClass(klass)
+      launcher.setAppResource(fakeSparkJar.getAbsolutePath())
+    }
+    launcher.setSparkHome(sys.props("spark.test.home"))
+      .setMaster("yarn")
+      .setDeployMode(deployMode)
+      .setConf("spark.executor.instances", "1")
+      .setPropertiesFile(propsFile)
+      .addAppArgs(appArgs.toArray: _*)
+
+    sparkArgs.foreach { case (name, value) =>
+      if (value != null) {
+        launcher.addSparkArg(name, value)
+      } else {
+        launcher.addSparkArg(name)
+      }
+    }
+    extraJars.foreach(launcher.addJar)
+
+    val handle = launcher.startApplication()
+    try {
+      eventually(timeout(2 minutes), interval(1 second)) {
+        assert(handle.getState().isFinal())
+      }
+    } finally {
+      handle.kill()
+    }
+
+    handle.getState()
+  }
+
+  /**
+   * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
+   * any sort of error when the job process finishes successfully, but the job itself fails. So
+   * the tests enforce that something is written to a file after everything is ok to indicate
+   * that the job succeeded.
+   */
+  protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = {
+    checkResult(finalState, result, "success")
+  }
+
+  protected def checkResult(
+      finalState: SparkAppHandle.State,
+      result: File,
+      expected: String): Unit = {
+    finalState should be (SparkAppHandle.State.FINISHED)
+    val resultString = Files.toString(result, StandardCharsets.UTF_8)
+    resultString should be (expected)
+  }
+
+  protected def mainClassName(klass: Class[_]): String = {
+    klass.getName().stripSuffix("$")
+  }
+
+  protected def createConfFile(
+      extraClassPath: Seq[String] = Nil,
+      extraConf: Map[String, String] = Map()): String = {
+    val props = new Properties()
+    props.put(SPARK_JARS.key, "local:" + fakeSparkJar.getAbsolutePath())
+
+    val testClasspath = new TestClasspathBuilder()
+      .buildClassPath(
+        logConfDir.getAbsolutePath() +
+        File.pathSeparator +
+        extraClassPath.mkString(File.pathSeparator))
+      .asScala
+      .mkString(File.pathSeparator)
+
+    props.put("spark.driver.extraClassPath", testClasspath)
+    props.put("spark.executor.extraClassPath", testClasspath)
+
+    // SPARK-4267: make sure java options are propagated correctly.
+    props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
+    props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
+
+    yarnCluster.getConfig().asScala.foreach { e =>
+      props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
+    }
+    sys.props.foreach { case (k, v) =>
+      if (k.startsWith("spark.")) {
+        props.setProperty(k, v)
+      }
+    }
+    extraConf.foreach { case (k, v) => props.setProperty(k, v) }
+
+    val propsFile = File.createTempFile("spark", ".properties", tempDir)
+    val writer = new OutputStreamWriter(new FileOutputStream(propsFile), StandardCharsets.UTF_8)
+    props.store(writer, "Spark properties.")
+    writer.close()
+    propsFile.getAbsolutePath()
+  }
+
+}
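
For reference, a minimal sketch of how a concrete suite could plug into the BaseYarnClusterSuite helpers above. The subclass, driver object, and test body below are hypothetical (they are not part of this commit) and only use the newYarnConfig(), runSpark(), mainClassName() and checkResult() members shown in the diff:

    package org.apache.spark.deploy.yarn

    import java.io.File
    import java.nio.charset.StandardCharsets

    import com.google.common.io.Files
    import org.apache.hadoop.yarn.conf.YarnConfiguration

    import org.apache.spark.{SparkConf, SparkContext}

    // Hypothetical driver: spark-submit runs its main(), which executes a trivial
    // job and then writes "success" into the result file passed as args(0).
    object ExampleDriver {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf())
        try {
          sc.parallelize(1 to 10).count()
          Files.write("success", new File(args(0)), StandardCharsets.UTF_8)
        } finally {
          sc.stop()
        }
      }
    }

    // Hypothetical concrete suite built on the abstract base class above.
    class ExampleYarnSuite extends BaseYarnClusterSuite {

      // Each concrete suite supplies the configuration for the MiniYARNCluster.
      override def newYarnConfig(): YarnConfiguration = new YarnConfiguration()

      test("cluster mode job writes its result file (illustrative)") {
        val result = File.createTempFile("result", null, tempDir)
        val finalState = runSpark(
          clientMode = false,
          klass = mainClassName(ExampleDriver.getClass),
          appArgs = Seq(result.getAbsolutePath()))
        // The final YARN state alone does not prove the job succeeded, hence the file check.
        checkResult(finalState, result)
      }
    }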

http://git-wip-us.apache.org/repos/asf/spark/blob/81e5619c/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
----------------------------------------------------------------------
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
new file mode 100644
index 0000000..b696e08
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.util.ConverterUtils
+import org.mockito.Mockito.when
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.yarn.config._
+
+class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar {
+
+  class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
+    override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+        LocalResourceVisibility = {
+      LocalResourceVisibility.PRIVATE
+    }
+  }
+
+  test("test getFileStatus empty") {
+    val distMgr = new ClientDistributedCacheManager()
+    val fs = mock[FileSystem]
+    val uri = new URI("/tmp/testing")
+    when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+    val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+    val stat = distMgr.getFileStatus(fs, uri, statCache)
+    assert(stat.getPath() === null)
+  }
+
+  test("test getFileStatus cached") {
+    val distMgr = new ClientDistributedCacheManager()
+    val fs = mock[FileSystem]
+    val uri = new URI("/tmp/testing")
+    val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+      null, new Path("/tmp/testing"))
+    when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+    val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
+    val stat = distMgr.getFileStatus(fs, uri, statCache)
+    assert(stat.getPath().toString() === "/tmp/testing")
+  }
+
+  test("test addResource") {
+    val distMgr = new MockClientDistributedCacheManager()
+    val fs = mock[FileSystem]
+    val conf = new Configuration()
+    val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+    val localResources = HashMap[String, LocalResource]()
+    val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+    when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+    distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+      statCache, false)
+    val resource = localResources("link")
+    assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+    assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+    assert(resource.getTimestamp() === 0)
+    assert(resource.getSize() === 0)
+    assert(resource.getType() === LocalResourceType.FILE)
+
+    val sparkConf = new SparkConf(false)
+    distMgr.updateConfiguration(sparkConf)
+    assert(sparkConf.get(CACHED_FILES) === Seq("file:/foo.invalid.com:8080/tmp/testing#link"))
+    assert(sparkConf.get(CACHED_FILES_TIMESTAMPS) === Seq(0L))
+    assert(sparkConf.get(CACHED_FILES_SIZES) === Seq(0L))
+    assert(sparkConf.get(CACHED_FILES_VISIBILITIES) === Seq(LocalResourceVisibility.PRIVATE.name()))
+    assert(sparkConf.get(CACHED_FILES_TYPES) === Seq(LocalResourceType.FILE.name()))
+
+    // add another one and verify both there and order correct
+    val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+      null, new Path("/tmp/testing2"))
+    val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
+    when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
+    distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+      statCache, false)
+    val resource2 = localResources("link2")
+    assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
+    assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2)
+    assert(resource2.getTimestamp() === 10)
+    assert(resource2.getSize() === 20)
+    assert(resource2.getType() === LocalResourceType.FILE)
+
+    val sparkConf2 = new SparkConf(false)
+    distMgr.updateConfiguration(sparkConf2)
+
+    val files = sparkConf2.get(CACHED_FILES)
+    val sizes = sparkConf2.get(CACHED_FILES_SIZES)
+    val timestamps = sparkConf2.get(CACHED_FILES_TIMESTAMPS)
+    val visibilities = sparkConf2.get(CACHED_FILES_VISIBILITIES)
+
+    assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
+    assert(timestamps(0)  === 0)
+    assert(sizes(0)  === 0)
+    assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name())
+
+    assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2")
+    assert(timestamps(1)  === 10)
+    assert(sizes(1)  === 20)
+    assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name())
+  }
+
+  test("test addResource link null") {
+    val distMgr = new MockClientDistributedCacheManager()
+    val fs = mock[FileSystem]
+    val conf = new Configuration()
+    val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+    val localResources = HashMap[String, LocalResource]()
+    val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+    when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+    intercept[Exception] {
+      distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+        statCache, false)
+    }
+    assert(localResources.get("link") === None)
+    assert(localResources.size === 0)
+  }
+
+  test("test addResource appmaster only") {
+    val distMgr = new MockClientDistributedCacheManager()
+    val fs = mock[FileSystem]
+    val conf = new Configuration()
+    val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+    val localResources = HashMap[String, LocalResource]()
+    val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+    val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+      null, new Path("/tmp/testing"))
+    when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+    distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+      statCache, true)
+    val resource = localResources("link")
+    assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+    assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+    assert(resource.getTimestamp() === 10)
+    assert(resource.getSize() === 20)
+    assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+    val sparkConf = new SparkConf(false)
+    distMgr.updateConfiguration(sparkConf)
+    assert(sparkConf.get(CACHED_FILES) === Nil)
+    assert(sparkConf.get(CACHED_FILES_TIMESTAMPS) === Nil)
+    assert(sparkConf.get(CACHED_FILES_SIZES) === Nil)
+    assert(sparkConf.get(CACHED_FILES_VISIBILITIES) === Nil)
+    assert(sparkConf.get(CACHED_FILES_TYPES) === Nil)
+  }
+
+  test("test addResource archive") {
+    val distMgr = new MockClientDistributedCacheManager()
+    val fs = mock[FileSystem]
+    val conf = new Configuration()
+    val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+    val localResources = HashMap[String, LocalResource]()
+    val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+    val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+      null, new Path("/tmp/testing"))
+    when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+    distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+      statCache, false)
+    val resource = localResources("link")
+    assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+    assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+    assert(resource.getTimestamp() === 10)
+    assert(resource.getSize() === 20)
+    assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+    val sparkConf = new SparkConf(false)
+    distMgr.updateConfiguration(sparkConf)
+    assert(sparkConf.get(CACHED_FILES) === Seq("file:/foo.invalid.com:8080/tmp/testing#link"))
+    assert(sparkConf.get(CACHED_FILES_SIZES) === Seq(20L))
+    assert(sparkConf.get(CACHED_FILES_TIMESTAMPS) === Seq(10L))
+    assert(sparkConf.get(CACHED_FILES_VISIBILITIES) === Seq(LocalResourceVisibility.PRIVATE.name()))
+    assert(sparkConf.get(CACHED_FILES_TYPES) === Seq(LocalResourceType.ARCHIVE.name()))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/81e5619c/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
----------------------------------------------------------------------
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
new file mode 100644
index 0000000..7deaf0a
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -0,0 +1,462 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.net.URI
+import java.util.Properties
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{HashMap => MutableHashMap}
+import scala.reflect.ClassTag
+import scala.util.Try
+
+import org.apache.commons.lang3.SerializationUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.MRJobConfig
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.YarnClientApplication
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.Records
+import org.mockito.Matchers.{eq => meq, _}
+import org.mockito.Mockito._
+import org.scalatest.{BeforeAndAfterAll, Matchers}
+
+import org.apache.spark.{SparkConf, SparkFunSuite, TestUtils}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils}
+
+class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
+  with ResetSystemProperties {
+
+  import Client._
+
+  var oldSystemProperties: Properties = null
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    oldSystemProperties = SerializationUtils.clone(System.getProperties)
+    System.setProperty("SPARK_YARN_MODE", "true")
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      System.setProperties(oldSystemProperties)
+      oldSystemProperties = null
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  test("default Yarn application classpath") {
+    getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP))
+  }
+
+  test("default MR application classpath") {
+    getDefaultMRApplicationClasspath should be(Some(Fixtures.knownDefMRAppCP))
+  }
+
+  test("resultant classpath for an application that defines a classpath for YARN") {
+    withAppConf(Fixtures.mapYARNAppConf) { conf =>
+      val env = newEnv
+      populateHadoopClasspath(conf, env)
+      classpath(env) should be(
+        flatten(Fixtures.knownYARNAppCP, getDefaultMRApplicationClasspath))
+    }
+  }
+
+  test("resultant classpath for an application that defines a classpath for MR") {
+    withAppConf(Fixtures.mapMRAppConf) { conf =>
+      val env = newEnv
+      populateHadoopClasspath(conf, env)
+      classpath(env) should be(
+        flatten(getDefaultYarnApplicationClasspath, Fixtures.knownMRAppCP))
+    }
+  }
+
+  test("resultant classpath for an application that defines both classpaths, YARN and MR") {
+    withAppConf(Fixtures.mapAppConf) { conf =>
+      val env = newEnv
+      populateHadoopClasspath(conf, env)
+      classpath(env) should be(flatten(Fixtures.knownYARNAppCP, Fixtures.knownMRAppCP))
+    }
+  }
+
+  private val SPARK = "local:/sparkJar"
+  private val USER = "local:/userJar"
+  private val ADDED = "local:/addJar1,local:/addJar2,/addJar3"
+
+  private val PWD =
+    if (classOf[Environment].getMethods().exists(_.getName == "$$")) {
+      "{{PWD}}"
+    } else if (Utils.isWindows) {
+      "%PWD%"
+    } else {
+      Environment.PWD.$()
+    }
+
+  test("Local jar URIs") {
+    val conf = new Configuration()
+    val sparkConf = new SparkConf()
+      .set(SPARK_JARS, Seq(SPARK))
+      .set(USER_CLASS_PATH_FIRST, true)
+      .set("spark.yarn.dist.jars", ADDED)
+    val env = new MutableHashMap[String, String]()
+    val args = new ClientArguments(Array("--jar", USER))
+
+    populateClasspath(args, conf, sparkConf, env)
+
+    val cp = env("CLASSPATH").split(":|;|<CPS>")
+    s"$SPARK,$USER,$ADDED".split(",").foreach({ entry =>
+      val uri = new URI(entry)
+      if (LOCAL_SCHEME.equals(uri.getScheme())) {
+        cp should contain (uri.getPath())
+      } else {
+        cp should not contain (uri.getPath())
+      }
+    })
+    cp should contain(PWD)
+    cp should contain (s"$PWD${Path.SEPARATOR}${LOCALIZED_CONF_DIR}")
+    cp should not contain (APP_JAR)
+  }
+
+  test("Jar path propagation through SparkConf") {
+    val conf = new Configuration()
+    val sparkConf = new SparkConf()
+      .set(SPARK_JARS, Seq(SPARK))
+      .set("spark.yarn.dist.jars", ADDED)
+    val client = createClient(sparkConf, args = Array("--jar", USER))
+    doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]),
+      any(classOf[Path]), anyShort(), anyBoolean(), any())
+
+    val tempDir = Utils.createTempDir()
+    try {
+      // Because we mocked "copyFileToRemote" above to avoid having to create fake local files,
+      // we need to create a fake config archive in the temp dir to avoid having
+      // prepareLocalResources throw an exception.
+      new FileOutputStream(new File(tempDir, LOCALIZED_CONF_ARCHIVE)).close()
+
+      client.prepareLocalResources(new Path(tempDir.getAbsolutePath()), Nil)
+      sparkConf.get(APP_JAR) should be (Some(USER))
+
+      // The non-local path should be propagated by name only, since it will end up in the app's
+      // staging dir.
+      val expected = ADDED.split(",")
+        .map(p => {
+          val uri = new URI(p)
+          if (LOCAL_SCHEME == uri.getScheme()) {
+            p
+          } else {
+            Option(uri.getFragment()).getOrElse(new File(p).getName())
+          }
+        })
+        .mkString(",")
+
+      sparkConf.get(SECONDARY_JARS) should be (Some(expected.split(",").toSeq))
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
+  test("Cluster path translation") {
+    val conf = new Configuration()
+    val sparkConf = new SparkConf()
+      .set(SPARK_JARS, Seq("local:/localPath/spark.jar"))
+      .set(GATEWAY_ROOT_PATH, "/localPath")
+      .set(REPLACEMENT_ROOT_PATH, "/remotePath")
+
+    getClusterPath(sparkConf, "/localPath") should be ("/remotePath")
+    getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be (
+      "/remotePath/1:/remotePath/2")
+
+    val env = new MutableHashMap[String, String]()
+    populateClasspath(null, conf, sparkConf, env, extraClassPath = Some("/localPath/my1.jar"))
+    val cp = classpath(env)
+    cp should contain ("/remotePath/spark.jar")
+    cp should contain ("/remotePath/my1.jar")
+  }
+
+  test("configuration and args propagate through createApplicationSubmissionContext") {
+    val conf = new Configuration()
+    // When parsing tags, duplicates and leading/trailing whitespace should be removed.
+    // Spaces between non-comma strings should be preserved as single tags. Empty strings may or
+    // may not be removed depending on the version of Hadoop being used.
+    val sparkConf = new SparkConf()
+      .set(APPLICATION_TAGS.key, ",tag1, dup,tag2 , ,multi word , dup")
+      .set(MAX_APP_ATTEMPTS, 42)
+      .set("spark.app.name", "foo-test-app")
+      .set(QUEUE_NAME, "staging-queue")
+    val args = new ClientArguments(Array())
+
+    val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
+    val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse])
+    val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext])
+
+    val client = new Client(args, conf, sparkConf)
+    client.createApplicationSubmissionContext(
+      new YarnClientApplication(getNewApplicationResponse, appContext),
+      containerLaunchContext)
+
+    appContext.getApplicationName should be ("foo-test-app")
+    appContext.getQueue should be ("staging-queue")
+    appContext.getAMContainerSpec should be (containerLaunchContext)
+    appContext.getApplicationType should be ("SPARK")
+    appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method =>
+      val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]]
+      tags should contain allOf ("tag1", "dup", "tag2", "multi word")
+      tags.asScala.count(_.nonEmpty) should be (4)
+    }
+    appContext.getMaxAppAttempts should be (42)
+  }
+
+  test("spark.yarn.jars with multiple paths and globs") {
+    val libs = Utils.createTempDir()
+    val single = Utils.createTempDir()
+    val jar1 = TestUtils.createJarWithFiles(Map(), libs)
+    val jar2 = TestUtils.createJarWithFiles(Map(), libs)
+    val jar3 = TestUtils.createJarWithFiles(Map(), single)
+    val jar4 = TestUtils.createJarWithFiles(Map(), single)
+
+    val jarsConf = Seq(
+      s"${libs.getAbsolutePath()}/*",
+      jar3.getPath(),
+      s"local:${jar4.getPath()}",
+      s"local:${single.getAbsolutePath()}/*")
+
+    val sparkConf = new SparkConf().set(SPARK_JARS, jarsConf)
+    val client = createClient(sparkConf)
+
+    val tempDir = Utils.createTempDir()
+    client.prepareLocalResources(new Path(tempDir.getAbsolutePath()), Nil)
+
+    assert(sparkConf.get(SPARK_JARS) ===
+      Some(Seq(s"local:${jar4.getPath()}", s"local:${single.getAbsolutePath()}/*")))
+
+    verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar1.toURI())), anyShort(),
+      anyBoolean(), any())
+    verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar2.toURI())), anyShort(),
+      anyBoolean(), any())
+    verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar3.toURI())), anyShort(),
+      anyBoolean(), any())
+
+    val cp = classpath(client)
+    cp should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*"))
+    cp should not contain (jar3.getPath())
+    cp should contain (jar4.getPath())
+    cp should contain (buildPath(single.getAbsolutePath(), "*"))
+  }
+
+  test("distribute jars archive") {
+    val temp = Utils.createTempDir()
+    val archive = TestUtils.createJarWithFiles(Map(), temp)
+
+    val sparkConf = new SparkConf().set(SPARK_ARCHIVE, archive.getPath())
+    val client = createClient(sparkConf)
+    client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil)
+
+    verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(archive.toURI())), anyShort(),
+      anyBoolean(), any())
+    classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*"))
+
+    sparkConf.set(SPARK_ARCHIVE, LOCAL_SCHEME + ":" + archive.getPath())
+    intercept[IllegalArgumentException] {
+      client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil)
+    }
+  }
+
+  test("distribute archive multiple times") {
+    val libs = Utils.createTempDir()
+    // Create jars dir and RELEASE file to avoid IllegalStateException.
+    val jarsDir = new File(libs, "jars")
+    assert(jarsDir.mkdir())
+    new FileOutputStream(new File(libs, "RELEASE")).close()
+
+    val userLib1 = Utils.createTempDir()
+    val testJar = TestUtils.createJarWithFiles(Map(), userLib1)
+
+    // Case 1:  FILES_TO_DISTRIBUTE and ARCHIVES_TO_DISTRIBUTE can't have duplicate files
+    val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+      .set(FILES_TO_DISTRIBUTE, Seq(testJar.getPath))
+      .set(ARCHIVES_TO_DISTRIBUTE, Seq(testJar.getPath))
+
+    val client = createClient(sparkConf)
+    val tempDir = Utils.createTempDir()
+    intercept[IllegalArgumentException] {
+      client.prepareLocalResources(new Path(tempDir.getAbsolutePath()), Nil)
+    }
+
+    // Case 2: FILES_TO_DISTRIBUTE can't have duplicate files.
+    val sparkConfFiles = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+      .set(FILES_TO_DISTRIBUTE, Seq(testJar.getPath, testJar.getPath))
+
+    val clientFiles = createClient(sparkConfFiles)
+    val tempDirForFiles = Utils.createTempDir()
+    intercept[IllegalArgumentException] {
+      clientFiles.prepareLocalResources(new Path(tempDirForFiles.getAbsolutePath()), Nil)
+    }
+
+    // Case 3: ARCHIVES_TO_DISTRIBUTE can't have duplicate files.
+    val sparkConfArchives = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+      .set(ARCHIVES_TO_DISTRIBUTE, Seq(testJar.getPath, testJar.getPath))
+
+    val clientArchives = createClient(sparkConfArchives)
+    val tempDirForArchives = Utils.createTempDir()
+    intercept[IllegalArgumentException] {
+      clientArchives.prepareLocalResources(new Path(tempDirForArchives.getAbsolutePath()), Nil)
+    }
+
+    // Case 4: FILES_TO_DISTRIBUTE can have unique file.
+    val sparkConfFilesUniq = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+      .set(FILES_TO_DISTRIBUTE, Seq(testJar.getPath))
+
+    val clientFilesUniq = createClient(sparkConfFilesUniq)
+    val tempDirForFilesUniq = Utils.createTempDir()
+    clientFilesUniq.prepareLocalResources(new Path(tempDirForFilesUniq.getAbsolutePath()), Nil)
+
+    // Case 5: ARCHIVES_TO_DISTRIBUTE can have unique file.
+    val sparkConfArchivesUniq = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+      .set(ARCHIVES_TO_DISTRIBUTE, Seq(testJar.getPath))
+
+    val clientArchivesUniq = createClient(sparkConfArchivesUniq)
+    val tempDirArchivesUniq = Utils.createTempDir()
+    clientArchivesUniq.prepareLocalResources(new Path(tempDirArchivesUniq.getAbsolutePath()), Nil)
+
+  }
+
+  test("distribute local spark jars") {
+    val temp = Utils.createTempDir()
+    val jarsDir = new File(temp, "jars")
+    assert(jarsDir.mkdir())
+    val jar = TestUtils.createJarWithFiles(Map(), jarsDir)
+    new FileOutputStream(new File(temp, "RELEASE")).close()
+
+    val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> temp.getAbsolutePath()))
+    val client = createClient(sparkConf)
+    client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil)
+    classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*"))
+  }
+
+  test("ignore same name jars") {
+    val libs = Utils.createTempDir()
+    val jarsDir = new File(libs, "jars")
+    assert(jarsDir.mkdir())
+    new FileOutputStream(new File(libs, "RELEASE")).close()
+    val userLib1 = Utils.createTempDir()
+    val userLib2 = Utils.createTempDir()
+
+    val jar1 = TestUtils.createJarWithFiles(Map(), jarsDir)
+    val jar2 = TestUtils.createJarWithFiles(Map(), userLib1)
+    // Copy jar2 to jar3 with same name
+    val jar3 = {
+      val target = new File(userLib2, new File(jar2.toURI).getName)
+      val input = new FileInputStream(jar2.getPath)
+      val output = new FileOutputStream(target)
+      Utils.copyStream(input, output, closeStreams = true)
+      target.toURI.toURL
+    }
+
+    val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+      .set(JARS_TO_DISTRIBUTE, Seq(jar2.getPath, jar3.getPath))
+
+    val client = createClient(sparkConf)
+    val tempDir = Utils.createTempDir()
+    client.prepareLocalResources(new Path(tempDir.getAbsolutePath()), Nil)
+
+    // Only jar2 will be added to SECONDARY_JARS, jar3 which has the same name with jar2 will be
+    // ignored.
+    sparkConf.get(SECONDARY_JARS) should be (Some(Seq(new File(jar2.toURI).getName)))
+  }
+
+  object Fixtures {
+
+    val knownDefYarnAppCP: Seq[String] =
+      getFieldValue[Array[String], Seq[String]](classOf[YarnConfiguration],
+                                                "DEFAULT_YARN_APPLICATION_CLASSPATH",
+                                                Seq[String]())(a => a.toSeq)
+
+
+    val knownDefMRAppCP: Seq[String] =
+      getFieldValue2[String, Array[String], Seq[String]](
+        classOf[MRJobConfig],
+        "DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH",
+        Seq[String]())(a => a.split(","))(a => a.toSeq)
+
+    val knownYARNAppCP = Some(Seq("/known/yarn/path"))
+
+    val knownMRAppCP = Some(Seq("/known/mr/path"))
+
+    val mapMRAppConf =
+      Map("mapreduce.application.classpath" -> knownMRAppCP.map(_.mkString(":")).get)
+
+    val mapYARNAppConf =
+      Map(YarnConfiguration.YARN_APPLICATION_CLASSPATH -> knownYARNAppCP.map(_.mkString(":")).get)
+
+    val mapAppConf = mapYARNAppConf ++ mapMRAppConf
+  }
+
+  def withAppConf(m: Map[String, String] = Map())(testCode: (Configuration) => Any) {
+    val conf = new Configuration
+    m.foreach { case (k, v) => conf.set(k, v, "ClientSpec") }
+    testCode(conf)
+  }
+
+  def newEnv: MutableHashMap[String, String] = MutableHashMap[String, String]()
+
+  def classpath(env: MutableHashMap[String, String]): Array[String] =
+    env(Environment.CLASSPATH.name).split(":|;|<CPS>")
+
+  def flatten(a: Option[Seq[String]], b: Option[Seq[String]]): Array[String] =
+    (a ++ b).flatten.toArray
+
+  def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = {
+    Try(clazz.getField(field))
+      .map(_.get(null).asInstanceOf[A])
+      .toOption
+      .map(mapTo)
+      .getOrElse(defaults)
+  }
+
+  def getFieldValue2[A: ClassTag, A1: ClassTag, B](
+        clazz: Class[_],
+        field: String,
+        defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = {
+    Try(clazz.getField(field)).map(_.get(null)).map {
+      case v: A => mapTo(v)
+      case v1: A1 => mapTo1(v1)
+      case _ => defaults
+    }.toOption.getOrElse(defaults)
+  }
+
+  private def createClient(
+      sparkConf: SparkConf,
+      conf: Configuration = new Configuration(),
+      args: Array[String] = Array()): Client = {
+    val clientArgs = new ClientArguments(args)
+    spy(new Client(clientArgs, conf, sparkConf))
+  }
+
+  private def classpath(client: Client): Array[String] = {
+    val env = new MutableHashMap[String, String]()
+    populateClasspath(null, client.hadoopConf, client.sparkConf, env)
+    classpath(env)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/81e5619c/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala
----------------------------------------------------------------------
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala
new file mode 100644
index 0000000..afb4b69
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.SparkFunSuite
+
+class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
+
+  private val yarnAllocatorSuite = new YarnAllocatorSuite
+  import yarnAllocatorSuite._
+
+  def createContainerRequest(nodes: Array[String]): ContainerRequest =
+    new ContainerRequest(containerResource, nodes, null, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)
+
+  override def beforeEach() {
+    yarnAllocatorSuite.beforeEach()
+  }
+
+  override def afterEach() {
+    yarnAllocatorSuite.afterEach()
+  }
+
+  test("allocate locality preferred containers with enough resource and no matched existed " +
+    "containers") {
+    // 1. All the locations of current containers cannot satisfy the new requirements
+    // 2. Current requested container number can fully satisfy the pending tasks.
+
+    val handler = createAllocator(2)
+    handler.updateResourceRequests()
+    handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2")))
+
+    val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+      3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10),
+        handler.allocatedHostToContainersMap, Seq.empty)
+
+    assert(localities.map(_.nodes) === Array(
+      Array("host3", "host4", "host5"),
+      Array("host3", "host4", "host5"),
+      Array("host3", "host4")))
+  }
+
+  test("allocate locality preferred containers with enough resource and partially matched " +
+    "containers") {
+    // 1. Parts of current containers' locations can satisfy the new requirements
+    // 2. Current requested container number can fully satisfy the pending tasks.
+
+    val handler = createAllocator(3)
+    handler.updateResourceRequests()
+    handler.handleAllocatedContainers(Array(
+      createContainer("host1"),
+      createContainer("host1"),
+      createContainer("host2")
+    ))
+
+    val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+      3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10),
+        handler.allocatedHostToContainersMap, Seq.empty)
+
+    assert(localities.map(_.nodes) ===
+      Array(null, Array("host2", "host3"), Array("host2", "host3")))
+  }
+
+  test("allocate locality preferred containers with limited resource and partially matched " +
+    "containers") {
+    // 1. Parts of current containers' locations can satisfy the new requirements
+    // 2. Current requested container number cannot fully satisfy the pending tasks.
+
+    val handler = createAllocator(3)
+    handler.updateResourceRequests()
+    handler.handleAllocatedContainers(Array(
+      createContainer("host1"),
+      createContainer("host1"),
+      createContainer("host2")
+    ))
+
+    val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+      1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10),
+        handler.allocatedHostToContainersMap, Seq.empty)
+
+    assert(localities.map(_.nodes) === Array(Array("host2", "host3")))
+  }
+
+  test("allocate locality preferred containers with fully matched containers") {
+    // Current containers' locations can fully satisfy the new requirements
+
+    val handler = createAllocator(5)
+    handler.updateResourceRequests()
+    handler.handleAllocatedContainers(Array(
+      createContainer("host1"),
+      createContainer("host1"),
+      createContainer("host2"),
+      createContainer("host2"),
+      createContainer("host3")
+    ))
+
+    val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+      3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10),
+        handler.allocatedHostToContainersMap, Seq.empty)
+
+    assert(localities.map(_.nodes) === Array(null, null, null))
+  }
+
+  test("allocate containers with no locality preference") {
+    // Request new container without locality preference
+
+    val handler = createAllocator(2)
+    handler.updateResourceRequests()
+    handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2")))
+
+    val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+      1, 0, Map.empty, handler.allocatedHostToContainersMap, Seq.empty)
+
+    assert(localities.map(_.nodes) === Array(null))
+  }
+
+  test("allocate locality preferred containers by considering the localities of pending requests") {
+    val handler = createAllocator(3)
+    handler.updateResourceRequests()
+    handler.handleAllocatedContainers(Array(
+      createContainer("host1"),
+      createContainer("host1"),
+      createContainer("host2")
+    ))
+
+    val pendingAllocationRequests = Seq(
+      createContainerRequest(Array("host2", "host3")),
+      createContainerRequest(Array("host1", "host4")))
+
+    val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+      1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10),
+        handler.allocatedHostToContainersMap, pendingAllocationRequests)
+
+    assert(localities.map(_.nodes) === Array(Array("host3")))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/81e5619c/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
----------------------------------------------------------------------
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
new file mode 100644
index 0000000..994dc75
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.util.{Arrays, List => JList}
+
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic
+import org.apache.hadoop.net.DNSToSwitchMapping
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.mockito.Mockito._
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.deploy.yarn.YarnAllocator._
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.scheduler.SplitInfo
+import org.apache.spark.util.ManualClock
+
+class MockResolver extends DNSToSwitchMapping {
+
+  override def resolve(names: JList[String]): JList[String] = {
+    if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2")
+    else Arrays.asList("/rack1")
+  }
+
+  override def reloadCachedMappings() {}
+
+  def reloadCachedMappings(names: JList[String]) {}
+}
+
+class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
+  val conf = new YarnConfiguration()
+  conf.setClass(
+    CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
+    classOf[MockResolver], classOf[DNSToSwitchMapping])
+
+  val sparkConf = new SparkConf()
+  sparkConf.set("spark.driver.host", "localhost")
+  sparkConf.set("spark.driver.port", "4040")
+  sparkConf.set(SPARK_JARS, Seq("notarealjar.jar"))
+  sparkConf.set("spark.yarn.launchContainers", "false")
+
+  val appAttemptId = ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0)
+
+  // Resource returned by YARN.  YARN can give larger containers than requested, so give 6 cores
+  // instead of the 5 requested and 3 GB instead of the 2 requested.
+  val containerResource = Resource.newInstance(3072, 6)
+
+  var rmClient: AMRMClient[ContainerRequest] = _
+
+  var containerNum = 0
+
+  override def beforeEach() {
+    super.beforeEach()
+    rmClient = AMRMClient.createAMRMClient()
+    rmClient.init(conf)
+    rmClient.start()
+  }
+
+  override def afterEach() {
+    try {
+      rmClient.stop()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) {
+    override def hashCode(): Int = 0
+    override def equals(other: Any): Boolean = false
+  }
+
+  def createAllocator(maxExecutors: Int = 5): YarnAllocator = {
+    val args = Array(
+      "--jar", "somejar.jar",
+      "--class", "SomeClass")
+    val sparkConfClone = sparkConf.clone()
+    sparkConfClone
+      .set("spark.executor.instances", maxExecutors.toString)
+      .set("spark.executor.cores", "5")
+      .set("spark.executor.memory", "2048")
+    new YarnAllocator(
+      "not used",
+      mock(classOf[RpcEndpointRef]),
+      conf,
+      sparkConfClone,
+      rmClient,
+      appAttemptId,
+      new SecurityManager(sparkConf),
+      Map())
+  }
+
+  def createContainer(host: String): Container = {
+    // When YARN 2.6+ is required, avoid deprecation by using version with long second arg
+    val containerId = ContainerId.newInstance(appAttemptId, containerNum)
+    containerNum += 1
+    val nodeId = NodeId.newInstance(host, 1000)
+    Container.newInstance(containerId, nodeId, "", containerResource, RM_REQUEST_PRIORITY, null)
+  }
+
+  test("single container allocated") {
+    // request a single container and receive it
+    val handler = createAllocator(1)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (1)
+
+    val container = createContainer("host1")
+    handler.handleAllocatedContainers(Array(container))
+
+    handler.getNumExecutorsRunning should be (1)
+    handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
+
+    val size = rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size
+    size should be (0)
+  }
+
+  test("container should not be created if requested number if met") {
+    // request a single container and receive it
+    val handler = createAllocator(1)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (1)
+
+    val container = createContainer("host1")
+    handler.handleAllocatedContainers(Array(container))
+
+    handler.getNumExecutorsRunning should be (1)
+    handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
+
+    val container2 = createContainer("host2")
+    handler.handleAllocatedContainers(Array(container2))
+    handler.getNumExecutorsRunning should be (1)
+  }
+
+  test("some containers allocated") {
+    // request a few containers and receive some of them
+    val handler = createAllocator(4)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (4)
+
+    val container1 = createContainer("host1")
+    val container2 = createContainer("host1")
+    val container3 = createContainer("host2")
+    handler.handleAllocatedContainers(Array(container1, container2, container3))
+
+    handler.getNumExecutorsRunning should be (3)
+    handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1")
+    handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host1")
+    handler.allocatedContainerToHostMap.get(container3.getId).get should be ("host2")
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId)
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container2.getId)
+    handler.allocatedHostToContainersMap.get("host2").get should contain (container3.getId)
+  }
+
+  test("receive more containers than requested") {
+    val handler = createAllocator(2)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (2)
+
+    val container1 = createContainer("host1")
+    val container2 = createContainer("host2")
+    val container3 = createContainer("host4")
+    handler.handleAllocatedContainers(Array(container1, container2, container3))
+
+    handler.getNumExecutorsRunning should be (2)
+    handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1")
+    handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host2")
+    handler.allocatedContainerToHostMap.contains(container3.getId) should be (false)
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId)
+    handler.allocatedHostToContainersMap.get("host2").get should contain (container2.getId)
+    handler.allocatedHostToContainersMap.contains("host4") should be (false)
+  }
+
+  test("decrease total requested executors") {
+    val handler = createAllocator(4)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (4)
+
+    handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty)
+    handler.updateResourceRequests()
+    handler.getPendingAllocate.size should be (3)
+
+    val container = createContainer("host1")
+    handler.handleAllocatedContainers(Array(container))
+
+    handler.getNumExecutorsRunning should be (1)
+    handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
+
+    handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty)
+    handler.updateResourceRequests()
+    handler.getPendingAllocate.size should be (1)
+  }
+
+  test("decrease total requested executors to less than currently running") {
+    val handler = createAllocator(4)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (4)
+
+    handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty)
+    handler.updateResourceRequests()
+    handler.getPendingAllocate.size should be (3)
+
+    val container1 = createContainer("host1")
+    val container2 = createContainer("host2")
+    handler.handleAllocatedContainers(Array(container1, container2))
+
+    handler.getNumExecutorsRunning should be (2)
+
+    handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty)
+    handler.updateResourceRequests()
+    handler.getPendingAllocate.size should be (0)
+    handler.getNumExecutorsRunning should be (2)
+  }
+
+  test("kill executors") {
+    val handler = createAllocator(4)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (4)
+
+    val container1 = createContainer("host1")
+    val container2 = createContainer("host2")
+    handler.handleAllocatedContainers(Array(container1, container2))
+
+    handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty)
+    handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) }
+
+    val statuses = Seq(container1, container2).map { c =>
+      ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0)
+    }
+    handler.updateResourceRequests()
+    handler.processCompletedContainers(statuses.toSeq)
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (1)
+  }
+
+  test("lost executor removed from backend") {
+    val handler = createAllocator(4)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (4)
+
+    val container1 = createContainer("host1")
+    val container2 = createContainer("host2")
+    handler.handleAllocatedContainers(Array(container1, container2))
+
+    handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map())
+
+    val statuses = Seq(container1, container2).map { c =>
+      ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1)
+    }
+    handler.updateResourceRequests()
+    handler.processCompletedContainers(statuses.toSeq)
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (2)
+    handler.getNumExecutorsFailed should be (2)
+    handler.getNumUnexpectedContainerRelease should be (2)
+  }
+
+  test("memory exceeded diagnostic regexes") {
+    val diagnostics =
+      "Container 
[pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +
+        "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical 
memory used; " +
+        "5.8 GB of 4.2 GB virtual memory used. Killing container."
+    val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN)
+    val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN)
+    assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used."))
+    assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used."))
+  }
+
+  test("window based failure executor counting") {
+    sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s")
+    val handler = createAllocator(4)
+    val clock = new ManualClock(0L)
+    handler.setClock(clock)
+
+    handler.updateResourceRequests()
+    handler.getNumExecutorsRunning should be (0)
+    handler.getPendingAllocate.size should be (4)
+
+    val containers = Seq(
+      createContainer("host1"),
+      createContainer("host2"),
+      createContainer("host3"),
+      createContainer("host4")
+    )
+    handler.handleAllocatedContainers(containers)
+
+    val failedStatuses = containers.map { c =>
+      ContainerStatus.newInstance(c.getId, ContainerState.COMPLETE, "Failed", -1)
+    }
+
+    handler.getNumExecutorsFailed should be (0)
+
+    clock.advance(100 * 1000L)
+    handler.processCompletedContainers(failedStatuses.slice(0, 1))
+    handler.getNumExecutorsFailed should be (1)
+
+    clock.advance(101 * 1000L)
+    handler.getNumExecutorsFailed should be (0)
+
+    handler.processCompletedContainers(failedStatuses.slice(1, 3))
+    handler.getNumExecutorsFailed should be (2)
+
+    clock.advance(50 * 1000L)
+    handler.processCompletedContainers(failedStatuses.slice(3, 4))
+    handler.getNumExecutorsFailed should be (3)
+
+    clock.advance(51 * 1000L)
+    handler.getNumExecutorsFailed should be (1)
+
+    clock.advance(50 * 1000L)
+    handler.getNumExecutorsFailed should be (0)
+  }
+}
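
The window-based failure test above relies on the allocator expiring executor failures once they fall outside spark.yarn.executor.failuresValidityInterval. The following standalone sketch is not the allocator's actual implementation (the class and method names are illustrative only); it only models the sliding-window counting so the clock advances in the test are easier to follow:

    // Minimal sketch of a validity window that turns "failures ever" into
    // "failures within the last N ms". Names are hypothetical.
    import scala.collection.mutable

    class WindowedFailureTracker(validityIntervalMs: Long) {
      // Timestamps (ms) of recorded failures, oldest first.
      private val failureTimes = mutable.Queue[Long]()

      def recordFailure(nowMs: Long): Unit = failureTimes.enqueue(nowMs)

      // Failures still inside the window; older entries are dropped lazily.
      def failedCount(nowMs: Long): Int = {
        while (failureTimes.nonEmpty && nowMs - failureTimes.head > validityIntervalMs) {
          failureTimes.dequeue()
        }
        failureTimes.size
      }
    }

    object WindowedFailureTrackerDemo {
      def main(args: Array[String]): Unit = {
        val tracker = new WindowedFailureTracker(100 * 1000L) // 100s, as in the test
        var now = 100 * 1000L
        tracker.recordFailure(now)             // failure recorded at t = 100s
        assert(tracker.failedCount(now) == 1)
        now += 101 * 1000L                     // advance the clock past the window
        assert(tracker.failedCount(now) == 0)  // the failure has expired
      }
    }

With a 100s window, a failure recorded at t = 100s still counts immediately but is dropped once the clock passes t = 200s, which mirrors the "advance 101s, count drops back to zero" assertions in the test.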

http://git-wip-us.apache.org/repos/asf/spark/blob/81e5619c/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
----------------------------------------------------------------------
diff --git 
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
new file mode 100644
index 0000000..99fb58a
--- /dev/null
+++ 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.net.URL
+import java.nio.charset.StandardCharsets
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import com.google.common.io.{ByteStreams, Files}
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.Matchers
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.launcher._
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
+  SparkListenerExecutorAdded}
+import org.apache.spark.scheduler.cluster.ExecutorInfo
+import org.apache.spark.tags.ExtendedYarnTest
+import org.apache.spark.util.Utils
+
+/**
+ * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN
+ * applications, and require the Spark assembly to be built before they can be successfully
+ * run.
+ */
+@ExtendedYarnTest
+class YarnClusterSuite extends BaseYarnClusterSuite {
+
+  override def newYarnConfig(): YarnConfiguration = new YarnConfiguration()
+
+  private val TEST_PYFILE = """
+    |import mod1, mod2
+    |import sys
+    |from operator import add
+    |
+    |from pyspark import SparkConf , SparkContext
+    |if __name__ == "__main__":
+    |    if len(sys.argv) != 2:
+    |        print >> sys.stderr, "Usage: test.py [result file]"
+    |        exit(-1)
+    |    sc = SparkContext(conf=SparkConf())
+    |    status = open(sys.argv[1],'w')
+    |    result = "failure"
+    |    rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func())
+    |    cnt = rdd.count()
+    |    if cnt == 10:
+    |        result = "success"
+    |    status.write(result)
+    |    status.close()
+    |    sc.stop()
+    """.stripMargin
+
+  private val TEST_PYMODULE = """
+    |def func():
+    |    return 42
+    """.stripMargin
+
+  test("run Spark in yarn-client mode") {
+    testBasicYarnApp(true)
+  }
+
+  test("run Spark in yarn-cluster mode") {
+    testBasicYarnApp(false)
+  }
+
+  test("run Spark in yarn-client mode with different configurations") {
+    testBasicYarnApp(true,
+      Map(
+        "spark.driver.memory" -> "512m",
+        "spark.executor.cores" -> "1",
+        "spark.executor.memory" -> "512m",
+        "spark.executor.instances" -> "2"
+      ))
+  }
+
+  test("run Spark in yarn-cluster mode with different configurations") {
+    testBasicYarnApp(false,
+      Map(
+        "spark.driver.memory" -> "512m",
+        "spark.driver.cores" -> "1",
+        "spark.executor.cores" -> "1",
+        "spark.executor.memory" -> "512m",
+        "spark.executor.instances" -> "2"
+      ))
+  }
+
+  test("run Spark in yarn-cluster mode with using SparkHadoopUtil.conf") {
+    testYarnAppUseSparkHadoopUtilConf()
+  }
+
+  test("run Spark in yarn-client mode with additional jar") {
+    testWithAddJar(true)
+  }
+
+  test("run Spark in yarn-cluster mode with additional jar") {
+    testWithAddJar(false)
+  }
+
+  test("run Spark in yarn-cluster mode unsuccessfully") {
+    // Don't provide arguments so the driver will fail.
+    val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass))
+    finalState should be (SparkAppHandle.State.FAILED)
+  }
+
+  test("run Spark in yarn-cluster mode failure after sc initialized") {
+    val finalState = runSpark(false, mainClassName(YarnClusterDriverWithFailure.getClass))
+    finalState should be (SparkAppHandle.State.FAILED)
+  }
+
+  test("run Python application in yarn-client mode") {
+    testPySpark(true)
+  }
+
+  test("run Python application in yarn-cluster mode") {
+    testPySpark(false)
+  }
+
+  test("run Python application in yarn-cluster mode using " +
+    " spark.yarn.appMasterEnv to override local envvar") {
+    testPySpark(
+      clientMode = false,
+      extraConf = Map(
+        "spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON"
+          -> sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python"),
+        "spark.yarn.appMasterEnv.PYSPARK_PYTHON"
+          -> sys.env.getOrElse("PYSPARK_PYTHON", "python")),
+      extraEnv = Map(
+        "PYSPARK_DRIVER_PYTHON" -> "not python",
+        "PYSPARK_PYTHON" -> "not python"))
+  }
+
+  test("user class path first in client mode") {
+    testUseClassPathFirst(true)
+  }
+
+  test("user class path first in cluster mode") {
+    testUseClassPathFirst(false)
+  }
+
+  test("monitor app using launcher library") {
+    val env = new JHashMap[String, String]()
+    env.put("YARN_CONF_DIR", hadoopConfDir.getAbsolutePath())
+
+    val propsFile = createConfFile()
+    val handle = new SparkLauncher(env)
+      .setSparkHome(sys.props("spark.test.home"))
+      .setConf("spark.ui.enabled", "false")
+      .setPropertiesFile(propsFile)
+      .setMaster("yarn")
+      .setDeployMode("client")
+      .setAppResource(SparkLauncher.NO_RESOURCE)
+      .setMainClass(mainClassName(YarnLauncherTestApp.getClass))
+      .startApplication()
+
+    try {
+      eventually(timeout(30 seconds), interval(100 millis)) {
+        handle.getState() should be (SparkAppHandle.State.RUNNING)
+      }
+
+      handle.getAppId() should not be (null)
+      handle.getAppId() should startWith ("application_")
+      handle.stop()
+
+      eventually(timeout(30 seconds), interval(100 millis)) {
+        handle.getState() should be (SparkAppHandle.State.KILLED)
+      }
+    } finally {
+      handle.kill()
+    }
+  }
+
+  test("timeout to get SparkContext in cluster mode triggers failure") {
+    val timeout = 2000
+    val finalState = runSpark(false, mainClassName(SparkContextTimeoutApp.getClass),
+      appArgs = Seq((timeout * 4).toString),
+      extraConf = Map(AM_MAX_WAIT_TIME.key -> timeout.toString))
+    finalState should be (SparkAppHandle.State.FAILED)
+  }
+
+  private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
+    val result = File.createTempFile("result", null, tempDir)
+    val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
+      appArgs = Seq(result.getAbsolutePath()),
+      extraConf = conf)
+    checkResult(finalState, result)
+  }
+
+  private def testYarnAppUseSparkHadoopUtilConf(): Unit = {
+    val result = File.createTempFile("result", null, tempDir)
+    val finalState = runSpark(false,
+      mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass),
+      appArgs = Seq("key=value", result.getAbsolutePath()),
+      extraConf = Map("spark.hadoop.key" -> "value"))
+    checkResult(finalState, result)
+  }
+
+  private def testWithAddJar(clientMode: Boolean): Unit = {
+    val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
+    val driverResult = File.createTempFile("driver", null, tempDir)
+    val executorResult = File.createTempFile("executor", null, tempDir)
+    val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
+      appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()),
+      extraClassPath = Seq(originalJar.getPath()),
+      extraJars = Seq("local:" + originalJar.getPath()))
+    checkResult(finalState, driverResult, "ORIGINAL")
+    checkResult(finalState, executorResult, "ORIGINAL")
+  }
+
+  private def testPySpark(
+      clientMode: Boolean,
+      extraConf: Map[String, String] = Map(),
+      extraEnv: Map[String, String] = Map()): Unit = {
+    val primaryPyFile = new File(tempDir, "test.py")
+    Files.write(TEST_PYFILE, primaryPyFile, StandardCharsets.UTF_8)
+
+    // When running tests, let's not assume the user has built the assembly module, which also
+    // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the
+    // needed locations.
+    val sparkHome = sys.props("spark.test.home")
+    val pythonPath = Seq(
+        s"$sparkHome/python/lib/py4j-0.10.4-src.zip",
+        s"$sparkHome/python")
+    val extraEnvVars = Map(
+      "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + 
_).mkString(File.pathSeparator),
+      "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv
+
+    val moduleDir =
+      if (clientMode) {
+        // In client-mode, .py files added with --py-files are not visible in the driver.
+        // This is something that the launcher library would have to handle.
+        tempDir
+      } else {
+        val subdir = new File(tempDir, "pyModules")
+        subdir.mkdir()
+        subdir
+      }
+    val pyModule = new File(moduleDir, "mod1.py")
+    Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8)
+
+    val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir)
+    val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",")
+    val result = File.createTempFile("result", null, tempDir)
+
+    val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(),
+      sparkArgs = Seq("--py-files" -> pyFiles),
+      appArgs = Seq(result.getAbsolutePath()),
+      extraEnv = extraEnvVars,
+      extraConf = extraConf)
+    checkResult(finalState, result)
+  }
+
+  private def testUseClassPathFirst(clientMode: Boolean): Unit = {
+    // Create a jar file that contains a different version of "test.resource".
+    val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
+    val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir)
+    val driverResult = File.createTempFile("driver", null, tempDir)
+    val executorResult = File.createTempFile("executor", null, tempDir)
+    val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
+      appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()),
+      extraClassPath = Seq(originalJar.getPath()),
+      extraJars = Seq("local:" + userJar.getPath()),
+      extraConf = Map(
+        "spark.driver.userClassPathFirst" -> "true",
+        "spark.executor.userClassPathFirst" -> "true"))
+    checkResult(finalState, driverResult, "OVERRIDDEN")
+    checkResult(finalState, executorResult, "OVERRIDDEN")
+  }
+
+}
+
+private[spark] class SaveExecutorInfo extends SparkListener {
+  val addedExecutorInfos = mutable.Map[String, ExecutorInfo]()
+  var driverLogs: Option[collection.Map[String, String]] = None
+
+  override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
+    addedExecutorInfos(executor.executorId) = executor.executorInfo
+  }
+
+  override def onApplicationStart(appStart: SparkListenerApplicationStart): Unit = {
+    driverLogs = appStart.driverLogs
+  }
+}
+
+private object YarnClusterDriverWithFailure extends Logging with Matchers {
+  def main(args: Array[String]): Unit = {
+    val sc = new SparkContext(new SparkConf()
+      .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
+      .setAppName("yarn test with failure"))
+
+    throw new Exception("exception after sc initialized")
+  }
+}
+
+private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matchers {
+  def main(args: Array[String]): Unit = {
+    if (args.length != 2) {
+      // scalastyle:off println
+      System.err.println(
+        s"""
+        |Invalid command line: ${args.mkString(" ")}
+        |
+        |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value] [result file]
+        """.stripMargin)
+      // scalastyle:on println
+      System.exit(1)
+    }
+
+    val sc = new SparkContext(new SparkConf()
+      .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
+      .setAppName("yarn test using SparkHadoopUtil's conf"))
+
+    val kv = args(0).split("=")
+    val status = new File(args(1))
+    var result = "failure"
+    try {
+      SparkHadoopUtil.get.conf.get(kv(0)) should be (kv(1))
+      result = "success"
+    } finally {
+      Files.write(result, status, StandardCharsets.UTF_8)
+      sc.stop()
+    }
+  }
+}
+
+private object YarnClusterDriver extends Logging with Matchers {
+
+  val WAIT_TIMEOUT_MILLIS = 10000
+
+  def main(args: Array[String]): Unit = {
+    if (args.length != 1) {
+      // scalastyle:off println
+      System.err.println(
+        s"""
+        |Invalid command line: ${args.mkString(" ")}
+        |
+        |Usage: YarnClusterDriver [result file]
+        """.stripMargin)
+      // scalastyle:on println
+      System.exit(1)
+    }
+
+    val sc = new SparkContext(new SparkConf()
+      .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
+      .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and 
$dollarSigns"))
+    val conf = sc.getConf
+    val status = new File(args(0))
+    var result = "failure"
+    try {
+      val data = sc.parallelize(1 to 4, 4).collect().toSet
+      sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+      data should be (Set(1, 2, 3, 4))
+      result = "success"
+
+      // Verify that the config archive is correctly placed in the classpath of all containers.
+      val confFile = "/" + Client.SPARK_CONF_FILE
+      assert(getClass().getResource(confFile) != null)
+      val configFromExecutors = sc.parallelize(1 to 4, 4)
+        .map { _ => Option(getClass().getResource(confFile)).map(_.toString).orNull }
+        .collect()
+      assert(configFromExecutors.find(_ == null) === None)
+    } finally {
+      Files.write(result, status, StandardCharsets.UTF_8)
+      sc.stop()
+    }
+
+    // verify log urls are present
+    val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]
+    assert(listeners.size === 1)
+    val listener = listeners(0)
+    val executorInfos = listener.addedExecutorInfos.values
+    assert(executorInfos.nonEmpty)
+    executorInfos.foreach { info =>
+      assert(info.logUrlMap.nonEmpty)
+    }
+
+    // If we are running in yarn-cluster mode, verify that the driver log links are present and
+    // in the expected format.
+    if (conf.get("spark.submit.deployMode") == "cluster") {
+      assert(listener.driverLogs.nonEmpty)
+      val driverLogs = listener.driverLogs.get
+      assert(driverLogs.size === 2)
+      assert(driverLogs.contains("stderr"))
+      assert(driverLogs.contains("stdout"))
+      val urlStr = driverLogs("stderr")
+      // Ensure that this is a valid URL, else this will throw an exception
+      new URL(urlStr)
+      val containerId = YarnSparkHadoopUtil.get.getContainerId
+      val user = Utils.getCurrentUserName()
+      assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096"))
+    }
+  }
+
+}
+
+private object YarnClasspathTest extends Logging {
+  def error(m: String, ex: Throwable = null): Unit = {
+    logError(m, ex)
+    // scalastyle:off println
+    System.out.println(m)
+    if (ex != null) {
+      ex.printStackTrace(System.out)
+    }
+    // scalastyle:on println
+  }
+
+  def main(args: Array[String]): Unit = {
+    if (args.length != 2) {
+      error(
+        s"""
+        |Invalid command line: ${args.mkString(" ")}
+        |
+        |Usage: YarnClasspathTest [driver result file] [executor result file]
+        """.stripMargin)
+      // scalastyle:on println
+    }
+
+    readResource(args(0))
+    val sc = new SparkContext(new SparkConf())
+    try {
+      sc.parallelize(Seq(1)).foreach { x => readResource(args(1)) }
+    } finally {
+      sc.stop()
+    }
+  }
+
+  private def readResource(resultPath: String): Unit = {
+    var result = "failure"
+    try {
+      val ccl = Thread.currentThread().getContextClassLoader()
+      val resource = ccl.getResourceAsStream("test.resource")
+      val bytes = ByteStreams.toByteArray(resource)
+      result = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8)
+    } catch {
+      case t: Throwable =>
+        error(s"loading test.resource to $resultPath", t)
+    } finally {
+      Files.write(result, new File(resultPath), StandardCharsets.UTF_8)
+    }
+  }
+
+}
+
+private object YarnLauncherTestApp {
+
+  def main(args: Array[String]): Unit = {
+    // Do not stop the application; the test will stop it using the launcher lib. Just run a task
+    // that will prevent the process from exiting.
+    val sc = new SparkContext(new SparkConf())
+    sc.parallelize(Seq(1)).foreach { i =>
+      this.synchronized {
+        wait()
+      }
+    }
+  }
+
+}
+
+/**
+ * Used to test code in the AM that detects the SparkContext instance. Expects a single argument
+ * with the duration to sleep for, in ms.
+ */
+private object SparkContextTimeoutApp {
+
+  def main(args: Array[String]): Unit = {
+    val Array(sleepTime) = args
+    Thread.sleep(java.lang.Long.parseLong(sleepTime))
+  }
+
+}
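
The "monitor app using launcher library" test above drives the application entirely through a SparkAppHandle. Below is a minimal sketch of the same pattern outside the test harness; the Spark home, app resource, and main class are placeholders, and the polling loop stands in for the test's eventually blocks:

    import org.apache.spark.launcher.SparkLauncher

    object LauncherMonitorSketch {
      def main(args: Array[String]): Unit = {
        val handle = new SparkLauncher()
          .setSparkHome("/path/to/spark")        // placeholder
          .setMaster("yarn")
          .setDeployMode("client")
          .setAppResource("/path/to/app.jar")    // placeholder
          .setMainClass("com.example.Main")      // placeholder
          .startApplication()

        // Poll the handle until YARN reports a terminal state; the suite instead
        // waits for RUNNING, calls handle.stop(), and expects KILLED.
        while (!handle.getState.isFinal) {
          Thread.sleep(1000)
        }
        println(s"appId=${handle.getAppId} finalState=${handle.getState}")
      }
    }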

http://git-wip-us.apache.org/repos/asf/spark/blob/81e5619c/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
----------------------------------------------------------------------
diff --git 
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
new file mode 100644
index 0000000..950ebd9
--- /dev/null
+++ 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
@@ -0,0 +1,112 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import com.google.common.io.Files
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.Matchers
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.ShuffleTestAccessor
+import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor}
+import org.apache.spark.tags.ExtendedYarnTest
+
+/**
+ * Integration test for the external shuffle service with a yarn mini-cluster
+ */
+@ExtendedYarnTest
+class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite {
+
+  override def newYarnConfig(): YarnConfiguration = {
+    val yarnConfig = new YarnConfiguration()
+    yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle")
+    yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"),
+      classOf[YarnShuffleService].getCanonicalName)
+    yarnConfig.set("spark.shuffle.service.port", "0")
+    yarnConfig
+  }
+
+  test("external shuffle service") {
+    val shuffleServicePort = YarnTestAccessor.getShuffleServicePort
+    val shuffleService = YarnTestAccessor.getShuffleServiceInstance
+
+    val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService)
+
+    logInfo("Shuffle service port = " + shuffleServicePort)
+    val result = File.createTempFile("result", null, tempDir)
+    val finalState = runSpark(
+      false,
+      mainClassName(YarnExternalShuffleDriver.getClass),
+      appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath),
+      extraConf = Map(
+        "spark.shuffle.service.enabled" -> "true",
+        "spark.shuffle.service.port" -> shuffleServicePort.toString
+      )
+    )
+    checkResult(finalState, result)
+    assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists())
+  }
+}
+
+private object YarnExternalShuffleDriver extends Logging with Matchers {
+
+  val WAIT_TIMEOUT_MILLIS = 10000
+
+  def main(args: Array[String]): Unit = {
+    if (args.length != 2) {
+      // scalastyle:off println
+      System.err.println(
+        s"""
+        |Invalid command line: ${args.mkString(" ")}
+        |
+        |Usage: ExternalShuffleDriver [result file] [registered exec file]
+        """.stripMargin)
+      // scalastyle:on println
+      System.exit(1)
+    }
+
+    val sc = new SparkContext(new SparkConf()
+      .setAppName("External Shuffle Test"))
+    val conf = sc.getConf
+    val status = new File(args(0))
+    val registeredExecFile = new File(args(1))
+    logInfo("shuffle service executor file = " + registeredExecFile)
+    var result = "failure"
+    val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup")
+    try {
+      val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }.
+        collect().toSet
+      sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+      data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet)
+      result = "success"
+      // only one process can open a leveldb file at a time, so we copy the files
+      FileUtils.copyDirectory(registeredExecFile, execStateCopy)
+      assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty)
+    } finally {
+      sc.stop()
+      FileUtils.deleteDirectory(execStateCopy)
+      Files.write(result, status, StandardCharsets.UTF_8)
+    }
+  }
+
+}
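
YarnShuffleIntegrationSuite wires the shuffle service into the NodeManager as an auxiliary service, then points executors at it through spark.shuffle.service.enabled and spark.shuffle.service.port. A rough sketch of the equivalent client-side configuration is below; 7337 is Spark's documented default port, whereas the test uses the mini-cluster's dynamically assigned one:

    import org.apache.spark.{SparkConf, SparkContext}

    object ExternalShuffleSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("external-shuffle-sketch")
          .set("spark.shuffle.service.enabled", "true")
          .set("spark.shuffle.service.port", "7337")
        val sc = new SparkContext(conf)
        try {
          // Any shuffle (here, reduceByKey) now fetches map output through the
          // external service rather than directly from executor processes.
          val sums = sc.parallelize(0 until 100, 10).map(x => (x % 10, x)).reduceByKey(_ + _).collect()
          println(sums.sorted.mkString(", "))
        } finally {
          sc.stop()
        }
      }
    }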

