This is an automated email from the ASF dual-hosted git repository.

irashid pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 52838e7  [SPARK-13704][CORE][YARN] Reduce rack resolution time
52838e7 is described below

commit 52838e74afdd58a7c09707284b4e0232dc01ef26
Author: LantaoJin <[email protected]>
AuthorDate: Mon Apr 8 10:47:06 2019 -0500

    [SPARK-13704][CORE][YARN] Reduce rack resolution time
    
    ## What changes were proposed in this pull request?
    
    When you submit a stage on a large cluster, rack resolving takes a long 
time when initializing TaskSetManager because a script is invoked to resolve 
the rack of each host, one by one.
    With the current implementation, it takes 30~40 seconds to resolve the 
racks in our 5000-node cluster. After applying the patch, this decreased to 
less than 15 seconds.
    
    YARN-9332 has added an interface to handle multiple hosts in one invocation 
to save time. But before upgrading to the newest Hadoop, we could construct the 
same tool in Spark to resolve this issue.
    
    ## How was this patch tested?
    
    Unit tests and manual testing on a 5000-node cluster.
    
    Closes #24245 from squito/SPARK-13704_update.
    
    Lead-authored-by: LantaoJin <[email protected]>
    Co-authored-by: Imran Rashid <[email protected]>
    Signed-off-by: Imran Rashid <[email protected]>
---
 .../apache/spark/scheduler/TaskSchedulerImpl.scala | 30 ++++++--
 .../apache/spark/scheduler/TaskSetManager.scala    | 33 ++++++--
 .../spark/scheduler/TaskSetManagerSuite.scala      | 76 +++++++++++++++++--
 ...calityPreferredContainerPlacementStrategy.scala |  5 +-
 .../spark/deploy/yarn/SparkRackResolver.scala      | 88 ++++++++++++++++++++--
 .../apache/spark/deploy/yarn/YarnAllocator.scala   |  2 +-
 .../apache/spark/deploy/yarn/YarnRMClient.scala    |  2 +-
 .../spark/scheduler/cluster/YarnScheduler.scala    | 20 ++---
 .../spark/deploy/yarn/YarnAllocatorSuite.scala     |  5 +-
 9 files changed, 219 insertions(+), 42 deletions(-)

diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index bffa1ff..e401c39 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -158,6 +158,8 @@ private[spark] class TaskSchedulerImpl(
 
   private[scheduler] var barrierCoordinator: RpcEndpoint = null
 
+  protected val defaultRackValue: Option[String] = None
+
   private def maybeInitBarrierCoordinator(): Unit = {
     if (barrierCoordinator == null) {
       barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, 
sc.listenerBus,
@@ -394,9 +396,10 @@ private[spark] class TaskSchedulerImpl(
         executorIdToRunningTaskIds(o.executorId) = HashSet[Long]()
         newExecAvail = true
       }
-      for (rack <- getRackForHost(o.host)) {
-        hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host
-      }
+    }
+    val hosts = offers.map(_.host).toSet.toSeq
+    for ((host, Some(rack)) <- hosts.zip(getRacksForHosts(hosts))) {
+      hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += host
     }
 
     // Before making any offers, remove any nodes from the blacklist whose 
blacklist has expired. Do
@@ -830,8 +833,25 @@ private[spark] class TaskSchedulerImpl(
     blacklistTrackerOpt.map(_.nodeBlacklist()).getOrElse(Set.empty)
   }
 
-  // By default, rack is unknown
-  def getRackForHost(value: String): Option[String] = None
+  /**
+   * Get the rack for one host.
+   *
+   * Note that [[getRacksForHosts]] should be preferred when possible as that 
can be much
+   * more efficient.
+   */
+  def getRackForHost(host: String): Option[String] = {
+    getRacksForHosts(Seq(host)).head
+  }
+
+  /**
+   * Get racks for multiple hosts.
+   *
+   * The returned Sequence will be the same length as the hosts argument and 
can be zipped
+   * together with the hosts argument.
+   */
+  def getRacksForHosts(hosts: Seq[String]): Seq[Option[String]] = {
+    hosts.map(_ => defaultRackValue)
+  }
 
   private def waitBackendReady(): Unit = {
     if (backend.isReady) {
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 3977c0b..1444220 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -186,8 +186,24 @@ private[spark] class TaskSetManager(
 
   // Add all our tasks to the pending lists. We do this in reverse order
   // of task index so that tasks with low indices get launched first.
-  for (i <- (0 until numTasks).reverse) {
-    addPendingTask(i)
+  addPendingTasks()
+
+  private def addPendingTasks(): Unit = {
+    val (_, duration) = Utils.timeTakenMs {
+      for (i <- (0 until numTasks).reverse) {
+        addPendingTask(i, resolveRacks = false)
+      }
+      // Resolve the rack for each host. This can be slow, so de-dupe the list 
of hosts,
+      // and assign the rack to all relevant task indices.
+      val (hosts, indicesForHosts) = pendingTasksForHost.toSeq.unzip
+      val racks = sched.getRacksForHosts(hosts)
+      racks.zip(indicesForHosts).foreach {
+        case (Some(rack), indices) =>
+          pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) ++= 
indices
+        case (None, _) => // no rack, nothing to do
+      }
+    }
+    logDebug(s"Adding pending tasks took $duration ms")
   }
 
   /**
@@ -214,7 +230,9 @@ private[spark] class TaskSetManager(
   private[scheduler] var emittedTaskSizeWarning = false
 
   /** Add a task to all the pending-task lists that it should be on. */
-  private[spark] def addPendingTask(index: Int) {
+  private[spark] def addPendingTask(
+      index: Int,
+      resolveRacks: Boolean = true): Unit = {
     for (loc <- tasks(index).preferredLocations) {
       loc match {
         case e: ExecutorCacheTaskLocation =>
@@ -234,8 +252,11 @@ private[spark] class TaskSetManager(
         case _ =>
       }
       pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index
-      for (rack <- sched.getRackForHost(loc.host)) {
-        pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index
+
+      if (resolveRacks) {
+        sched.getRackForHost(loc.host).foreach { rack =>
+          pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index
+        }
       }
     }
 
@@ -331,7 +352,7 @@ private[spark] class TaskSetManager(
         val executors = prefs.flatMap(_ match {
           case e: ExecutorCacheTaskLocation => Some(e.executorId)
           case _ => None
-        });
+        })
         if (executors.contains(execId)) {
           speculatableTasks -= index
           return Some((index, TaskLocality.PROCESS_LOCAL))
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ad03194..79160d0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -22,8 +22,8 @@ import java.util.{Properties, Random}
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.mockito.ArgumentMatchers.{any, anyInt, anyString}
-import org.mockito.Mockito.{mock, never, spy, times, verify, when}
+import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt, anyString}
+import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 
 import org.apache.spark._
@@ -68,17 +68,27 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: 
FakeTaskScheduler)
 // Get the rack for a given host
 object FakeRackUtil {
   private val hostToRack = new mutable.HashMap[String, String]()
+  var numBatchInvocation = 0
+  var numSingleHostInvocation = 0
 
   def cleanUp() {
     hostToRack.clear()
+    numBatchInvocation = 0
+    numSingleHostInvocation = 0
   }
 
   def assignHostToRack(host: String, rack: String) {
     hostToRack(host) = rack
   }
 
-  def getRackForHost(host: String): Option[String] = {
-    hostToRack.get(host)
+  def getRacksForHosts(hosts: Seq[String]): Seq[Option[String]] = {
+    assert(hosts.toSet.size == hosts.size) // no dups in hosts
+    if (hosts.nonEmpty && hosts.length != 1) {
+      numBatchInvocation += 1
+    } else if (hosts.length == 1) {
+      numSingleHostInvocation += 1
+    }
+    hosts.map(hostToRack.get(_))
   }
 }
 
@@ -98,6 +108,9 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: 
(String, String)* /* ex
   val speculativeTasks = new ArrayBuffer[Int]
 
   val executors = new mutable.HashMap[String, String]
+
+  // this must be initialized before addExecutor
+  override val defaultRackValue: Option[String] = Some("default")
   for ((execId, host) <- liveExecutors) {
     addExecutor(execId, host)
   }
@@ -143,8 +156,9 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: 
(String, String)* /* ex
     }
   }
 
-
-  override def getRackForHost(value: String): Option[String] = 
FakeRackUtil.getRackForHost(value)
+  override def getRacksForHosts(hosts: Seq[String]): Seq[Option[String]] = {
+    FakeRackUtil.getRacksForHosts(hosts)
+  }
 }
 
 /**
@@ -1311,7 +1325,7 @@ class TaskSetManagerSuite extends SparkFunSuite with 
LocalSparkContext with Logg
     val taskDesc = taskSetManagerSpy.resourceOffer(exec, host, 
TaskLocality.ANY)
 
     // Assert the task has been black listed on the executor it was last 
executed on.
-    when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer(
+    when(taskSetManagerSpy.addPendingTask(anyInt(), anyBoolean())).thenAnswer(
       (invocationOnMock: InvocationOnMock) => {
         val task: Int = invocationOnMock.getArgument(0)
         assert(taskSetManager.taskSetBlacklistHelperOpt.get.
@@ -1323,7 +1337,7 @@ class TaskSetManagerSuite extends SparkFunSuite with 
LocalSparkContext with Logg
     val e = new ExceptionFailure("a", "b", Array(), "c", None)
     taskSetManagerSpy.handleFailedTask(taskDesc.get.taskId, TaskState.FAILED, 
e)
 
-    verify(taskSetManagerSpy, times(1)).addPendingTask(anyInt())
+    verify(taskSetManagerSpy, times(1)).addPendingTask(0, false)
   }
 
   test("SPARK-21563 context's added jars shouldn't change mid-TaskSet") {
@@ -1595,4 +1609,50 @@ class TaskSetManagerSuite extends SparkFunSuite with 
LocalSparkContext with Logg
     verify(sched.dagScheduler).taskEnded(manager.tasks(3), Success, 
result.value(),
       result.accumUpdates, info3)
   }
+
+  test("SPARK-13704 Rack Resolution is done with a batch of de-duped hosts") {
+    val conf = new SparkConf()
+      .set(config.LOCALITY_WAIT, 0L)
+      .set(config.LOCALITY_WAIT_RACK, 1L)
+    sc = new SparkContext("local", "test", conf)
+    // Create a cluster with 20 racks, with hosts spread out among them
+    val execAndHost = (0 to 199).map { i =>
+      FakeRackUtil.assignHostToRack("host" + i, "rack" + (i % 20))
+      ("exec" + i, "host" + i)
+    }
+    sched = new FakeTaskScheduler(sc, execAndHost: _*)
+    // make a taskset with preferred locations on the first 100 hosts in our 
cluster
+    val locations = new ArrayBuffer[Seq[TaskLocation]]()
+    for (i <- 0 to 99) {
+      locations += Seq(TaskLocation("host" + i))
+    }
+    val taskSet = FakeTask.createTaskSet(100, locations: _*)
+    val clock = new ManualClock
+    // make sure we only do one rack resolution call, for the entire batch of 
hosts, as this
+    // can be expensive.  The FakeTaskScheduler calls rack resolution more 
than the real one
+    // -- that is outside of the scope of this test, we just want to check the 
task set manager.
+    FakeRackUtil.numBatchInvocation = 0
+    FakeRackUtil.numSingleHostInvocation = 0
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock 
= clock)
+    assert(FakeRackUtil.numBatchInvocation === 1)
+    assert(FakeRackUtil.numSingleHostInvocation === 0)
+    // with rack locality, reject an offer on a host with an unknown rack
+    assert(manager.resourceOffer("otherExec", "otherHost", 
TaskLocality.RACK_LOCAL).isEmpty)
+    (0 until 20).foreach { rackIdx =>
+      (0 until 5).foreach { offerIdx =>
+        // if we offer hosts which are not in preferred locations,
+        // we'll reject them at NODE_LOCAL level,
+        // but accept them at RACK_LOCAL level if they're on OK racks
+        val hostIdx = 100 + rackIdx
+        assert(manager.resourceOffer("exec" + hostIdx, "host" + hostIdx, 
TaskLocality.NODE_LOCAL)
+          .isEmpty)
+        assert(manager.resourceOffer("exec" + hostIdx, "host" + hostIdx, 
TaskLocality.RACK_LOCAL)
+          .isDefined)
+      }
+    }
+    // check no more expensive calls to the rack resolution.  
manager.resourceOffer() will call
+    // the single-host resolution, but the real rack resolution would have 
cached all hosts
+    // by that point.
+    assert(FakeRackUtil.numBatchInvocation === 1)
+  }
 }
diff --git 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
index 0a7a16f..2288bb5 100644
--- 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
+++ 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
@@ -138,9 +138,8 @@ private[yarn] class 
LocalityPreferredContainerPlacementStrategy(
         // Only filter out the ratio which is larger than 0, which means the 
current host can
         // still be allocated with new container request.
         val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray
-        val racks = hosts.map { h =>
-          resolver.resolve(yarnConf, h)
-        }.toSet
+        val racks = resolver.resolve(hosts).map(_.getNetworkLocation)
+          .filter(_ != null).toSet
         containerLocalityPreferences += ContainerLocalityPreferences(hosts, 
racks.toArray)
 
         // Minus 1 each time when the host is used. When the current ratio is 
0,
diff --git 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
index c711d08..cab3272 100644
--- 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
+++ 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
@@ -17,24 +17,100 @@
 
 package org.apache.spark.deploy.yarn
 
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import com.google.common.base.Strings
 import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic
+import org.apache.hadoop.net._
+import org.apache.hadoop.util.ReflectionUtils
 import org.apache.hadoop.yarn.util.RackResolver
 import org.apache.log4j.{Level, Logger}
 
+import org.apache.spark.internal.Logging
+
 /**
- * Wrapper around YARN's [[RackResolver]]. This allows Spark tests to easily 
override the
- * default behavior, since YARN's class self-initializes the first time it's 
called, and
- * future calls all use the initial configuration.
+ * Re-implement YARN's [[RackResolver]] for hadoop releases without YARN-9332.
+ * This also allows Spark tests to easily override the default behavior, since 
YARN's class
+ * self-initializes the first time it's called, and future calls all use the 
initial configuration.
  */
-private[yarn] class SparkRackResolver {
+private[spark] class SparkRackResolver(conf: Configuration) extends Logging {
 
   // RackResolver logs an INFO message whenever it resolves a rack, which is 
way too often.
   if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
     Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
   }
 
-  def resolve(conf: Configuration, hostName: String): String = {
-    RackResolver.resolve(conf, hostName).getNetworkLocation()
+  private val dnsToSwitchMapping: DNSToSwitchMapping = {
+    val dnsToSwitchMappingClass =
+      
conf.getClass(CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
+        classOf[ScriptBasedMapping], classOf[DNSToSwitchMapping])
+    ReflectionUtils.newInstance(dnsToSwitchMappingClass, conf)
+        .asInstanceOf[DNSToSwitchMapping] match {
+      case c: CachedDNSToSwitchMapping => c
+      case o => new CachedDNSToSwitchMapping(o)
+    }
+  }
+
+  def resolve(hostName: String): String = {
+    coreResolve(Seq(hostName)).head.getNetworkLocation
+  }
+
+  /**
+   * Added in SPARK-13704.
+   * This should be changed to `RackResolver.resolve(conf, hostNames)`
+   * in hadoop releases with YARN-9332.
+   */
+  def resolve(hostNames: Seq[String]): Seq[Node] = {
+    coreResolve(hostNames)
+  }
+
+  private def coreResolve(hostNames: Seq[String]): Seq[Node] = {
+    val nodes = new ArrayBuffer[Node]
+    // dnsToSwitchMapping is thread-safe
+    val rNameList = dnsToSwitchMapping.resolve(hostNames.toList.asJava).asScala
+    if (rNameList == null || rNameList.isEmpty) {
+      hostNames.foreach(nodes += new NodeBase(_, NetworkTopology.DEFAULT_RACK))
+      logInfo(s"Got an error when resolving hostNames. " +
+        s"Falling back to ${NetworkTopology.DEFAULT_RACK} for all")
+    } else {
+      for ((hostName, rName) <- hostNames.zip(rNameList)) {
+        if (Strings.isNullOrEmpty(rName)) {
+          nodes += new NodeBase(hostName, NetworkTopology.DEFAULT_RACK)
+          logDebug(s"Could not resolve $hostName. " +
+            s"Falling back to ${NetworkTopology.DEFAULT_RACK}")
+        } else {
+          nodes += new NodeBase(hostName, rName)
+        }
+      }
+    }
+    nodes.toList
+  }
+}
+
+/**
+ * Utility to resolve the rack for hosts in an efficient manner.
+ * It will cache the rack for individual hosts to avoid
+ * repeatedly performing the same expensive lookup.
+ */
+object SparkRackResolver extends Logging {
+  @volatile private var instance: SparkRackResolver = _
+
+  /**
+   * It will return the static resolver instance.  If there is already an 
instance, the passed
+   * conf is entirely ignored.  If there is not a shared instance, it will 
create one with the
+   * given conf.
+   */
+  def get(conf: Configuration): SparkRackResolver = {
+    if (instance == null) {
+      synchronized {
+        if (instance == null) {
+          instance = new SparkRackResolver(conf)
+        }
+      }
+    }
+    instance
   }
 
 }
diff --git 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 5d939cf..1dc9d49 100644
--- 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -432,7 +432,7 @@ private[yarn] class YarnAllocator(
         override def run(): Unit = {
           try {
             for (allocatedContainer <- remainingAfterHostMatches) {
-              val rack = resolver.resolve(conf, 
allocatedContainer.getNodeId.getHost)
+              val rack = resolver.resolve(allocatedContainer.getNodeId.getHost)
               matchContainerToRequest(allocatedContainer, rack, 
containersToUse,
                 remainingAfterRackMatches)
             }
diff --git 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index cf16edf..7c67493 100644
--- 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -83,7 +83,7 @@ private[spark] class YarnRMClient extends Logging {
       localResources: Map[String, LocalResource]): YarnAllocator = {
     require(registered, "Must register AM before creating allocator.")
     new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, 
appAttemptId, securityMgr,
-      localResources, new SparkRackResolver())
+      localResources, SparkRackResolver.get(conf))
   }
 
   /**
diff --git 
a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
 
b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
index 0293821..d466ed7 100644
--- 
a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
+++ 
b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
@@ -17,23 +17,23 @@
 
 package org.apache.spark.scheduler.cluster
 
-import org.apache.hadoop.yarn.util.RackResolver
-import org.apache.log4j.{Level, Logger}
+import org.apache.hadoop.net.NetworkTopology
 
 import org.apache.spark._
+import org.apache.spark.deploy.yarn.SparkRackResolver
 import org.apache.spark.scheduler.TaskSchedulerImpl
 import org.apache.spark.util.Utils
 
 private[spark] class YarnScheduler(sc: SparkContext) extends 
TaskSchedulerImpl(sc) {
 
-  // RackResolver logs an INFO message whenever it resolves a rack, which is 
way too often.
-  if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
-    Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
-  }
+  override val defaultRackValue: Option[String] = 
Some(NetworkTopology.DEFAULT_RACK)
+
+  private[spark] val resolver = SparkRackResolver.get(sc.hadoopConfiguration)
 
-  // By default, rack is unknown
-  override def getRackForHost(hostPort: String): Option[String] = {
-    val host = Utils.parseHostPort(hostPort)._1
-    Option(RackResolver.resolve(sc.hadoopConfiguration, 
host).getNetworkLocation)
+  override def getRacksForHosts(hostPorts: Seq[String]): Seq[Option[String]] = 
{
+    val hosts = hostPorts.map(Utils.parseHostPort(_)._1)
+    resolver.resolve(hosts).map { node =>
+      Option(node.getNetworkLocation)
+    }
   }
 }
diff --git 
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 42b5966..59291af 100644
--- 
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -31,6 +31,7 @@ import org.mockito.Mockito._
 import org.scalatest.{BeforeAndAfterEach, Matchers}
 
 import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
 import org.apache.spark.deploy.yarn.config._
 import org.apache.spark.internal.config._
@@ -38,9 +39,9 @@ import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.scheduler.SplitInfo
 import org.apache.spark.util.ManualClock
 
-class MockResolver extends SparkRackResolver {
+class MockResolver extends SparkRackResolver(SparkHadoopUtil.get.conf) {
 
-  override def resolve(conf: Configuration, hostName: String): String = {
+  override def resolve(hostName: String): String = {
     if (hostName == "host3") "/rack2" else "/rack1"
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to