This is an automated email from the ASF dual-hosted git repository.
irashid pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2f0a38c [SPARK-29398][CORE] Support dedicated thread pools for RPC endpoints
2f0a38c is described below
commit 2f0a38cb50e3e8b4b72219c7b2b8b15d51f6b931
Author: Marcelo Vanzin <[email protected]>
AuthorDate: Thu Oct 17 13:14:32 2019 -0500
[SPARK-29398][CORE] Support dedicated thread pools for RPC endpoints
The current RPC backend in Spark supports single- and multi-threaded
message delivery to endpoints, but all endpoints share the same
underlying thread pool. As a result, an RPC endpoint that blocks a
dispatcher thread can negatively affect other endpoints.
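For illustration, a minimal sketch of the failure mode; the endpoint
class is hypothetical, while RpcEndpoint, RpcCallContext and
receiveAndReply are the existing API:

    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    // A badly written handler that stalls a shared dispatcher thread;
    // while it sleeps, delivery to every other endpoint sharing the
    // pool is delayed behind it.
    class SlowEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg =>
          Thread.sleep(60 * 1000) // occupies a dispatcher thread for a minute
          context.reply(msg)
      }
    }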
This can be more pronounced when the number of RPC dispatch threads is
limited by configuration and / or the running environment. And exposing
the RPC layer to other code (for example with something like
SPARK-29396) could make it easy for a badly written RPC handler to
disrupt normal Spark operation.
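For reference, a sketch of how small the shared pool can be made; the
values are illustrative, the typed setter is Spark-internal, and both
settings are read by getNumOfThreads in the code below:

    import org.apache.spark.SparkConf
    import org.apache.spark.internal.config.Network

    // Cap the shared RPC dispatch pool at two threads; the role-specific
    // key (spark.driver.* / spark.executor.*) wins when set.
    val conf = new SparkConf()
      .set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 2)
      .set("spark.driver.rpc.netty.dispatcher.numThreads", "2")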
This change adds a new RPC endpoint type that tells the RPC env to
create dedicated dispatch threads, so that those effects are minimised.
Other endpoints will still need CPU to process their messages, but
they won't be able to actively block the dispatch thread of these
isolated endpoints.
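A minimal sketch of how an endpoint opts in (the class name is
hypothetical; the trait and threadCount() are added by this change):

    import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEnv}

    // Gets its own "dispatcher-<name>" thread(s) by extending
    // IsolatedRpcEndpoint instead of ThreadSafeRpcEndpoint.
    class CriticalEndpoint(override val rpcEnv: RpcEnv) extends IsolatedRpcEndpoint {
      // Defaults to 1; requesting more threads means messages may be
      // delivered concurrently and out of order.
      override def threadCount(): Int = 1

      override def receive: PartialFunction[Any, Unit] = {
        case msg => // handled without competing with the shared pool
      }
    }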
As part of the change, I've converted the most important Spark endpoints
(the driver, executor and block manager endpoints) to be isolated from
the others. This means a couple of extra threads are created on the
driver and executor for these endpoints.
Tested with existing unit tests, which hammer the RPC system extensively,
and also by running applications on a cluster (with a prototype of
SPARK-29396).
Closes #26059 from vanzin/SPARK-29398.
Authored-by: Marcelo Vanzin <[email protected]>
Signed-off-by: Imran Rashid <[email protected]>
---
.../executor/CoarseGrainedExecutorBackend.scala | 2 +-
.../scala/org/apache/spark/rpc/RpcEndpoint.scala | 16 ++
.../org/apache/spark/rpc/netty/Dispatcher.scala | 130 ++++----------
.../scala/org/apache/spark/rpc/netty/Inbox.scala | 6 +-
.../org/apache/spark/rpc/netty/MessageLoop.scala | 194 +++++++++++++++++++++
.../cluster/CoarseGrainedSchedulerBackend.scala | 2 +-
.../spark/storage/BlockManagerMasterEndpoint.scala | 4 +-
.../spark/storage/BlockManagerSlaveEndpoint.scala | 4 +-
.../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 35 +++-
.../org/apache/spark/rpc/netty/InboxSuite.scala | 23 +--
10 files changed, 296 insertions(+), 120 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index fbf2dc7..b4bca1e 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -51,7 +51,7 @@ private[spark] class CoarseGrainedExecutorBackend(
userClassPath: Seq[URL],
env: SparkEnv,
resourcesFileOpt: Option[String])
- extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
+ extends IsolatedRpcEndpoint with ExecutorBackend with Logging {
private implicit val formats = DefaultFormats
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index 97eed54..4728759 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -146,3 +146,19 @@ private[spark] trait RpcEndpoint {
* [[ThreadSafeRpcEndpoint]] for different messages.
*/
private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+/**
+ * An endpoint that uses a dedicated thread pool for delivering messages.
+ */
+private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint {
+
+ /**
+ * How many threads to use for delivering messages. By default, use a single thread.
+ *
+ * Note that requesting more than one thread means that the endpoint should be able to handle
+ * messages arriving from many threads at once, and all the things that entails (including
+ * messages being delivered to the endpoint out of order).
+ */
+ def threadCount(): Int = 1
+
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index 2f923d7..27c943d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -17,20 +17,16 @@
package org.apache.spark.rpc.netty
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch}
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import scala.concurrent.Promise
-import scala.util.control.NonFatal
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.EXECUTOR_ID
-import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc._
-import org.apache.spark.util.ThreadUtils
/**
* A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
@@ -40,20 +36,23 @@ import org.apache.spark.util.ThreadUtils
*/
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
- private class EndpointData(
- val name: String,
- val endpoint: RpcEndpoint,
- val ref: NettyRpcEndpointRef) {
- val inbox = new Inbox(ref, endpoint)
- }
-
- private val endpoints: ConcurrentMap[String, EndpointData] =
- new ConcurrentHashMap[String, EndpointData]
+ private val endpoints: ConcurrentMap[String, MessageLoop] =
+ new ConcurrentHashMap[String, MessageLoop]
private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
- // Track the receivers whose inboxes may contain messages.
- private val receivers = new LinkedBlockingQueue[EndpointData]
+ private val shutdownLatch = new CountDownLatch(1)
+ private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores)
+
+ private def getMessageLoop(name: String, endpoint: RpcEndpoint): MessageLoop = {
+ endpoint match {
+ case e: IsolatedRpcEndpoint =>
+ new DedicatedMessageLoop(name, e, this)
+ case _ =>
+ sharedLoop.register(name, endpoint)
+ sharedLoop
+ }
+ }
/**
* True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
@@ -69,13 +68,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
if (stopped) {
throw new IllegalStateException("RpcEnv has been stopped")
}
- if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
+ if (endpoints.putIfAbsent(name, getMessageLoop(name, endpoint)) != null) {
throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
}
- val data = endpoints.get(name)
- endpointRefs.put(data.endpoint, data.ref)
- receivers.offer(data) // for the OnStart message
}
+ endpointRefs.put(endpoint, endpointRef)
endpointRef
}
@@ -85,10 +82,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
// Should be idempotent
private def unregisterRpcEndpoint(name: String): Unit = {
- val data = endpoints.remove(name)
- if (data != null) {
- data.inbox.stop()
- receivers.offer(data) // for the OnStop message
+ val loop = endpoints.remove(name)
+ if (loop != null) {
+ loop.unregister(name)
}
// Don't clean `endpointRefs` here because it's possible that some messages are being processed
// now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
@@ -155,14 +151,13 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
message: InboxMessage,
callbackIfStopped: (Exception) => Unit): Unit = {
val error = synchronized {
- val data = endpoints.get(endpointName)
+ val loop = endpoints.get(endpointName)
if (stopped) {
Some(new RpcEnvStoppedException())
- } else if (data == null) {
+ } else if (loop == null) {
Some(new SparkException(s"Could not find $endpointName."))
} else {
- data.inbox.post(message)
- receivers.offer(data)
+ loop.post(endpointName, message)
None
}
}
@@ -177,15 +172,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
}
stopped = true
}
- // Stop all endpoints. This will queue all endpoints for processing by the message loops.
- endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
- // Enqueue a message that tells the message loops to stop.
- receivers.offer(PoisonPill)
- threadpool.shutdown()
+ var stopSharedLoop = false
+ endpoints.asScala.foreach { case (name, loop) =>
+ unregisterRpcEndpoint(name)
+ if (!loop.isInstanceOf[SharedMessageLoop]) {
+ loop.stop()
+ } else {
+ stopSharedLoop = true
+ }
+ }
+ if (stopSharedLoop) {
+ sharedLoop.stop()
+ }
+ shutdownLatch.countDown()
}
def awaitTermination(): Unit = {
- threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+ shutdownLatch.await()
}
/**
@@ -194,61 +197,4 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
def verify(name: String): Boolean = {
endpoints.containsKey(name)
}
-
- private def getNumOfThreads(conf: SparkConf): Int = {
- val availableCores =
- if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
-
- val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
- .getOrElse(math.max(2, availableCores))
-
- conf.get(EXECUTOR_ID).map { id =>
- val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
- conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
- }.getOrElse(modNumThreads)
- }
-
- /** Thread pool used for dispatching messages. */
- private val threadpool: ThreadPoolExecutor = {
- val numThreads = getNumOfThreads(nettyEnv.conf)
- val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
- for (i <- 0 until numThreads) {
- pool.execute(new MessageLoop)
- }
- pool
- }
-
- /** Message loop used for dispatching messages. */
- private class MessageLoop extends Runnable {
- override def run(): Unit = {
- try {
- while (true) {
- try {
- val data = receivers.take()
- if (data == PoisonPill) {
- // Put PoisonPill back so that other MessageLoops can see it.
- receivers.offer(PoisonPill)
- return
- }
- data.inbox.process(Dispatcher.this)
- } catch {
- case NonFatal(e) => logError(e.getMessage, e)
- }
- }
- } catch {
- case _: InterruptedException => // exit
- case t: Throwable =>
- try {
- // Re-submit a MessageLoop so that Dispatcher will still work if
- // UncaughtExceptionHandler decides to not kill JVM.
- threadpool.execute(new MessageLoop)
- } finally {
- throw t
- }
- }
- }
- }
-
- /** A poison endpoint that indicates MessageLoop should exit its message loop. */
- private val PoisonPill = new EndpointData(null, null, null)
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index 44d2622..2ed03f7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -54,9 +54,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteA
/**
* An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
*/
-private[netty] class Inbox(
- val endpointRef: NettyRpcEndpointRef,
- val endpoint: RpcEndpoint)
+private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint)
extends Logging {
inbox => // Give this an alias so we can use it more clearly in closures.
@@ -195,7 +193,7 @@ private[netty] class Inbox(
* Exposed for testing.
*/
protected def onDrop(message: InboxMessage): Unit = {
- logWarning(s"Drop $message because $endpointRef is stopped")
+ logWarning(s"Drop $message because endpoint $endpointName is stopped")
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala
new file mode 100644
index 0000000..c985c72
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent._
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.EXECUTOR_ID
+import org.apache.spark.internal.config.Network._
+import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEndpoint}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A message loop used by [[Dispatcher]] to deliver messages to endpoints.
+ */
+private sealed abstract class MessageLoop(dispatcher: Dispatcher) extends Logging {
+
+ // List of inboxes with pending messages, to be processed by the message loop.
+ private val active = new LinkedBlockingQueue[Inbox]()
+
+ // Message loop task; should be run in all threads of the message loop's pool.
+ protected val receiveLoopRunnable = new Runnable() {
+ override def run(): Unit = receiveLoop()
+ }
+
+ protected val threadpool: ExecutorService
+
+ private var stopped = false
+
+ def post(endpointName: String, message: InboxMessage): Unit
+
+ def unregister(name: String): Unit
+
+ def stop(): Unit = {
+ synchronized {
+ if (!stopped) {
+ setActive(MessageLoop.PoisonPill)
+ threadpool.shutdown()
+ stopped = true
+ }
+ }
+ threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+ }
+
+ protected final def setActive(inbox: Inbox): Unit = active.offer(inbox)
+
+ private def receiveLoop(): Unit = {
+ try {
+ while (true) {
+ try {
+ val inbox = active.take()
+ if (inbox == MessageLoop.PoisonPill) {
+ // Put PoisonPill back so that other threads can see it.
+ setActive(MessageLoop.PoisonPill)
+ return
+ }
+ inbox.process(dispatcher)
+ } catch {
+ case NonFatal(e) => logError(e.getMessage, e)
+ }
+ }
+ } catch {
+ case _: InterruptedException => // exit
+ case t: Throwable =>
+ try {
+ // Re-submit a receive task so that message delivery will still work if
+ // UncaughtExceptionHandler decides to not kill JVM.
+ threadpool.execute(receiveLoopRunnable)
+ } finally {
+ throw t
+ }
+ }
+ }
+}
+
+private object MessageLoop {
+ /** A poison inbox that indicates the message loop should stop processing messages. */
+ val PoisonPill = new Inbox(null, null)
+}
+
+/**
+ * A message loop that serves multiple RPC endpoints, using a shared thread pool.
+ */
+private class SharedMessageLoop(
+ conf: SparkConf,
+ dispatcher: Dispatcher,
+ numUsableCores: Int)
+ extends MessageLoop(dispatcher) {
+
+ private val endpoints = new ConcurrentHashMap[String, Inbox]()
+
+ private def getNumOfThreads(conf: SparkConf): Int = {
+ val availableCores =
+ if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
+
+ val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
+ .getOrElse(math.max(2, availableCores))
+
+ conf.get(EXECUTOR_ID).map { id =>
+ val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
+ conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
+ }.getOrElse(modNumThreads)
+ }
+
+ /** Thread pool used for dispatching messages. */
+ override protected val threadpool: ThreadPoolExecutor = {
+ val numThreads = getNumOfThreads(conf)
+ val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
+ for (i <- 0 until numThreads) {
+ pool.execute(receiveLoopRunnable)
+ }
+ pool
+ }
+
+ override def post(endpointName: String, message: InboxMessage): Unit = {
+ val inbox = endpoints.get(endpointName)
+ inbox.post(message)
+ setActive(inbox)
+ }
+
+ override def unregister(name: String): Unit = {
+ val inbox = endpoints.remove(name)
+ if (inbox != null) {
+ inbox.stop()
+ // Mark active to handle the OnStop message.
+ setActive(inbox)
+ }
+ }
+
+ def register(name: String, endpoint: RpcEndpoint): Unit = {
+ val inbox = new Inbox(name, endpoint)
+ endpoints.put(name, inbox)
+ // Mark active to handle the OnStart message.
+ setActive(inbox)
+ }
+}
+
+/**
+ * A message loop that is dedicated to a single RPC endpoint.
+ */
+private class DedicatedMessageLoop(
+ name: String,
+ endpoint: IsolatedRpcEndpoint,
+ dispatcher: Dispatcher)
+ extends MessageLoop(dispatcher) {
+
+ private val inbox = new Inbox(name, endpoint)
+
+ override protected val threadpool = if (endpoint.threadCount() > 1) {
+ ThreadUtils.newDaemonCachedThreadPool(s"dispatcher-$name",
endpoint.threadCount())
+ } else {
+ ThreadUtils.newDaemonSingleThreadExecutor(s"dispatcher-$name")
+ }
+
+ (1 to endpoint.threadCount()).foreach { _ =>
+ threadpool.submit(receiveLoopRunnable)
+ }
+
+ // Mark active to handle the OnStart message.
+ setActive(inbox)
+
+ override def post(endpointName: String, message: InboxMessage): Unit = {
+ require(endpointName == name)
+ inbox.post(message)
+ setActive(inbox)
+ }
+
+ override def unregister(endpointName: String): Unit = synchronized {
+ require(endpointName == name)
+ inbox.stop()
+ // Mark active to handle the OnStop message.
+ setActive(inbox)
+ setActive(MessageLoop.PoisonPill)
+ threadpool.shutdown()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 4958389..6e990d1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -111,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
private val reviveThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread")
- class DriverEndpoint extends ThreadSafeRpcEndpoint with Logging {
+ class DriverEndpoint extends IsolatedRpcEndpoint with Logging {
override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index faf6f71..02d0e1a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -30,7 +30,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.network.shuffle.ExternalBlockStoreClient
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler._
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
@@ -46,7 +46,7 @@ class BlockManagerMasterEndpoint(
conf: SparkConf,
listenerBus: LiveListenerBus,
externalBlockStoreClient: Option[ExternalBlockStoreClient])
- extends ThreadSafeRpcEndpoint with Logging {
+ extends IsolatedRpcEndpoint with Logging {
// Mapping from block manager id to the block manager's information.
private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
index f90595a..29e2114 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future}
import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
-import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -34,7 +34,7 @@ class BlockManagerSlaveEndpoint(
override val rpcEnv: RpcEnv,
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
- extends ThreadSafeRpcEndpoint with Logging {
+ extends IsolatedRpcEndpoint with Logging {
private val asyncThreadPool =
ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool",
100)
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 5929fbf..c10f2c2 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -36,7 +36,6 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.config._
-import org.apache.spark.internal.config.Network
import org.apache.spark.util.{ThreadUtils, Utils}
/**
@@ -954,6 +953,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
verify(endpoint, never()).onDisconnected(any())
verify(endpoint, never()).onNetworkError(any(), any())
}
+
+ test("isolated endpoints") {
+ val latch = new CountDownLatch(1)
+ val singleThreadedEnv = createRpcEnv(
+ new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0)
+ try {
+ val blockingEndpoint = singleThreadedEnv.setupEndpoint("blocking", new
IsolatedRpcEndpoint {
+ override val rpcEnv: RpcEnv = singleThreadedEnv
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case m =>
+ latch.await()
+ context.reply(m)
+ }
+ })
+
+ val nonBlockingEndpoint = singleThreadedEnv.setupEndpoint("non-blocking", new RpcEndpoint {
+ override val rpcEnv: RpcEnv = singleThreadedEnv
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case m => context.reply(m)
+ }
+ })
+
+ val to = new RpcTimeout(5.seconds, "test-timeout")
+ val blockingFuture = blockingEndpoint.ask[String]("hi", to)
+ assert(nonBlockingEndpoint.askSync[String]("hello", to) === "hello")
+ latch.countDown()
+ assert(ThreadUtils.awaitResult(blockingFuture, 5.seconds) === "hi")
+ } finally {
+ latch.countDown()
+ singleThreadedEnv.shutdown()
+ }
+ }
}
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
index e553956..c74c728 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
@@ -29,12 +29,9 @@ class InboxSuite extends SparkFunSuite {
test("post") {
val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
- when(endpointRef.name).thenReturn("hello")
-
val dispatcher = mock(classOf[Dispatcher])
- val inbox = new Inbox(endpointRef, endpoint)
+ val inbox = new Inbox("name", endpoint)
val message = OneWayMessage(null, "hi")
inbox.post(message)
inbox.process(dispatcher)
@@ -51,10 +48,9 @@ class InboxSuite extends SparkFunSuite {
test("post: with reply") {
val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
- val inbox = new Inbox(endpointRef, endpoint)
+ val inbox = new Inbox("name", endpoint)
val message = RpcMessage(null, "hi", null)
inbox.post(message)
inbox.process(dispatcher)
@@ -65,13 +61,10 @@ class InboxSuite extends SparkFunSuite {
test("post: multiple threads") {
val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
- when(endpointRef.name).thenReturn("hello")
-
val dispatcher = mock(classOf[Dispatcher])
val numDroppedMessages = new AtomicInteger(0)
- val inbox = new Inbox(endpointRef, endpoint) {
+ val inbox = new Inbox("name", endpoint) {
override def onDrop(message: InboxMessage): Unit = {
numDroppedMessages.incrementAndGet()
}
@@ -107,12 +100,10 @@ class InboxSuite extends SparkFunSuite {
test("post: Associated") {
val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
-
val remoteAddress = RpcAddress("localhost", 11111)
- val inbox = new Inbox(endpointRef, endpoint)
+ val inbox = new Inbox("name", endpoint)
inbox.post(RemoteProcessConnected(remoteAddress))
inbox.process(dispatcher)
@@ -121,12 +112,11 @@ class InboxSuite extends SparkFunSuite {
test("post: Disassociated") {
val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
val remoteAddress = RpcAddress("localhost", 11111)
- val inbox = new Inbox(endpointRef, endpoint)
+ val inbox = new Inbox("name", endpoint)
inbox.post(RemoteProcessDisconnected(remoteAddress))
inbox.process(dispatcher)
@@ -135,13 +125,12 @@ class InboxSuite extends SparkFunSuite {
test("post: AssociationError") {
val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
val remoteAddress = RpcAddress("localhost", 11111)
val cause = new RuntimeException("Oops")
- val inbox = new Inbox(endpointRef, endpoint)
+ val inbox = new Inbox("name", endpoint)
inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
inbox.process(dispatcher)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]