This is an automated email from the ASF dual-hosted git repository.

irashid pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f0a38c  [SPARK-29398][CORE] Support dedicated thread pools for RPC endpoints
2f0a38c is described below

commit 2f0a38cb50e3e8b4b72219c7b2b8b15d51f6b931
Author: Marcelo Vanzin <[email protected]>
AuthorDate: Thu Oct 17 13:14:32 2019 -0500

    [SPARK-29398][CORE] Support dedicated thread pools for RPC endpoints
    
    The current RPC backend in Spark supports single- and multi-threaded
    message delivery to endpoints, but all endpoints share the same
    underlying thread pool. As a result, an RPC endpoint that blocks a
    dispatcher thread can negatively affect every other endpoint.
    
    This is more pronounced in deployments that limit the number of RPC
    dispatch threads, whether by configuration or by the running
    environment. And exposing the RPC layer to other code (for example
    with something like SPARK-29396) could make it easy to disrupt normal
    Spark operation with a badly written RPC handler.
    
    This change adds a new RPC endpoint type that tells the RPC env to
    create dedicated dispatch threads, so that those effects are minimised.
    Other endpoints will still need CPU to process their messages, but
    they won't be able to actively block the dispatch thread of these
    isolated endpoints.
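
    As an illustration (a minimal sketch, not code from this change), an
    endpoint opts into a dedicated message loop simply by extending the
    new trait; MyEndpoint below is hypothetical:

        // Hypothetical endpoint, for illustration only: extending
        // IsolatedRpcEndpoint gives it its own dispatch thread(s).
        class MyEndpoint(override val rpcEnv: RpcEnv) extends IsolatedRpcEndpoint {
          override def receive: PartialFunction[Any, Unit] = {
            case msg => // slow work here no longer stalls other endpoints
          }
        }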
    
    As part of the change, I've converted the most important Spark
    endpoints (the driver, executor and block manager endpoints) to be
    isolated from the others. This means a couple of extra threads are
    created on the driver and executor for these endpoints.
    
    Tested with existing unit tests, which hammer the RPC system extensively,
    and also by running applications on a cluster (with a prototype of
    SPARK-29396).
    
    Closes #26059 from vanzin/SPARK-29398.
    
    Authored-by: Marcelo Vanzin <[email protected]>
    Signed-off-by: Imran Rashid <[email protected]>
---
 .../executor/CoarseGrainedExecutorBackend.scala    |   2 +-
 .../scala/org/apache/spark/rpc/RpcEndpoint.scala   |  16 ++
 .../org/apache/spark/rpc/netty/Dispatcher.scala    | 130 ++++----------
 .../scala/org/apache/spark/rpc/netty/Inbox.scala   |   6 +-
 .../org/apache/spark/rpc/netty/MessageLoop.scala   | 194 +++++++++++++++++++++
 .../cluster/CoarseGrainedSchedulerBackend.scala    |   2 +-
 .../spark/storage/BlockManagerMasterEndpoint.scala |   4 +-
 .../spark/storage/BlockManagerSlaveEndpoint.scala  |   4 +-
 .../scala/org/apache/spark/rpc/RpcEnvSuite.scala   |  35 +++-
 .../org/apache/spark/rpc/netty/InboxSuite.scala    |  23 +--
 10 files changed, 296 insertions(+), 120 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index fbf2dc7..b4bca1e 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -51,7 +51,7 @@ private[spark] class CoarseGrainedExecutorBackend(
     userClassPath: Seq[URL],
     env: SparkEnv,
     resourcesFileOpt: Option[String])
-  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
+  extends IsolatedRpcEndpoint with ExecutorBackend with Logging {
 
   private implicit val formats = DefaultFormats
 
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index 97eed54..4728759 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -146,3 +146,19 @@ private[spark] trait RpcEndpoint {
  * [[ThreadSafeRpcEndpoint]] for different messages.
  */
 private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+/**
+ * An endpoint that uses a dedicated thread pool for delivering messages.
+ */
+private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint {
+
+  /**
+   * How many threads to use for delivering messages. By default, use a single thread.
+   *
+   * Note that requesting more than one thread means that the endpoint should be able to handle
+   * messages arriving from many threads at once, and all the things that entails (including
+   * messages being delivered to the endpoint out of order).
+   */
+  def threadCount(): Int = 1
+
+}
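
As the new doc comment notes, requesting more than one thread means the endpoint
must tolerate concurrent, possibly out-of-order delivery. A minimal sketch of a
multi-threaded isolated endpoint (hypothetical, not part of this change):

    import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEnv}

    class ParallelEndpoint(override val rpcEnv: RpcEnv) extends IsolatedRpcEndpoint {
      // Request four dedicated dispatch threads; messages may now be
      // delivered concurrently and out of order, so any shared state
      // inside the handler must be synchronized by the endpoint itself.
      override def threadCount(): Int = 4

      override def receive: PartialFunction[Any, Unit] = {
        case msg => // handler body must be thread-safe
      }
    }
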
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index 2f923d7..27c943d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -17,20 +17,16 @@
 
 package org.apache.spark.rpc.netty
 
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.concurrent.Promise
-import scala.util.control.NonFatal
 
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.EXECUTOR_ID
-import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS
 import org.apache.spark.network.client.RpcResponseCallback
 import org.apache.spark.rpc._
-import org.apache.spark.util.ThreadUtils
 
 /**
 * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
@@ -40,20 +36,23 @@ import org.apache.spark.util.ThreadUtils
  */
 private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
 
-  private class EndpointData(
-      val name: String,
-      val endpoint: RpcEndpoint,
-      val ref: NettyRpcEndpointRef) {
-    val inbox = new Inbox(ref, endpoint)
-  }
-
-  private val endpoints: ConcurrentMap[String, EndpointData] =
-    new ConcurrentHashMap[String, EndpointData]
+  private val endpoints: ConcurrentMap[String, MessageLoop] =
+    new ConcurrentHashMap[String, MessageLoop]
   private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
     new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
 
-  // Track the receivers whose inboxes may contain messages.
-  private val receivers = new LinkedBlockingQueue[EndpointData]
+  private val shutdownLatch = new CountDownLatch(1)
+  private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores)
+
+  private def getMessageLoop(name: String, endpoint: RpcEndpoint): MessageLoop = {
+    endpoint match {
+      case e: IsolatedRpcEndpoint =>
+        new DedicatedMessageLoop(name, e, this)
+      case _ =>
+        sharedLoop.register(name, endpoint)
+        sharedLoop
+    }
+  }
 
   /**
   * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
@@ -69,13 +68,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
       if (stopped) {
         throw new IllegalStateException("RpcEnv has been stopped")
       }
-      if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
+      if (endpoints.putIfAbsent(name, getMessageLoop(name, endpoint)) != null) {
         throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
       }
-      val data = endpoints.get(name)
-      endpointRefs.put(data.endpoint, data.ref)
-      receivers.offer(data)  // for the OnStart message
     }
+    endpointRefs.put(endpoint, endpointRef)
     endpointRef
   }
 
@@ -85,10 +82,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
 
   // Should be idempotent
   private def unregisterRpcEndpoint(name: String): Unit = {
-    val data = endpoints.remove(name)
-    if (data != null) {
-      data.inbox.stop()
-      receivers.offer(data)  // for the OnStop message
+    val loop = endpoints.remove(name)
+    if (loop != null) {
+      loop.unregister(name)
     }
     // Don't clean `endpointRefs` here because it's possible that some messages are being processed
     // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
@@ -155,14 +151,13 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
       message: InboxMessage,
       callbackIfStopped: (Exception) => Unit): Unit = {
     val error = synchronized {
-      val data = endpoints.get(endpointName)
+      val loop = endpoints.get(endpointName)
       if (stopped) {
         Some(new RpcEnvStoppedException())
-      } else if (data == null) {
+      } else if (loop == null) {
         Some(new SparkException(s"Could not find $endpointName."))
       } else {
-        data.inbox.post(message)
-        receivers.offer(data)
+        loop.post(endpointName, message)
         None
       }
     }
@@ -177,15 +172,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
       }
       stopped = true
     }
-    // Stop all endpoints. This will queue all endpoints for processing by the message loops.
-    endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
-    // Enqueue a message that tells the message loops to stop.
-    receivers.offer(PoisonPill)
-    threadpool.shutdown()
+    var stopSharedLoop = false
+    endpoints.asScala.foreach { case (name, loop) =>
+      unregisterRpcEndpoint(name)
+      if (!loop.isInstanceOf[SharedMessageLoop]) {
+        loop.stop()
+      } else {
+        stopSharedLoop = true
+      }
+    }
+    if (stopSharedLoop) {
+      sharedLoop.stop()
+    }
+    shutdownLatch.countDown()
   }
 
   def awaitTermination(): Unit = {
-    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+    shutdownLatch.await()
   }
 
   /**
@@ -194,61 +197,4 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
   def verify(name: String): Boolean = {
     endpoints.containsKey(name)
   }
-
-  private def getNumOfThreads(conf: SparkConf): Int = {
-    val availableCores =
-      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
-
-    val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
-      .getOrElse(math.max(2, availableCores))
-
-    conf.get(EXECUTOR_ID).map { id =>
-      val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
-      conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
-    }.getOrElse(modNumThreads)
-  }
-
-  /** Thread pool used for dispatching messages. */
-  private val threadpool: ThreadPoolExecutor = {
-    val numThreads = getNumOfThreads(nettyEnv.conf)
-    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
-    for (i <- 0 until numThreads) {
-      pool.execute(new MessageLoop)
-    }
-    pool
-  }
-
-  /** Message loop used for dispatching messages. */
-  private class MessageLoop extends Runnable {
-    override def run(): Unit = {
-      try {
-        while (true) {
-          try {
-            val data = receivers.take()
-            if (data == PoisonPill) {
-              // Put PoisonPill back so that other MessageLoops can see it.
-              receivers.offer(PoisonPill)
-              return
-            }
-            data.inbox.process(Dispatcher.this)
-          } catch {
-            case NonFatal(e) => logError(e.getMessage, e)
-          }
-        }
-      } catch {
-        case _: InterruptedException => // exit
-        case t: Throwable =>
-          try {
-            // Re-submit a MessageLoop so that Dispatcher will still work if
-            // UncaughtExceptionHandler decides to not kill JVM.
-            threadpool.execute(new MessageLoop)
-          } finally {
-            throw t
-          }
-      }
-    }
-  }
-
-  /** A poison endpoint that indicates MessageLoop should exit its message loop. */
-  private val PoisonPill = new EndpointData(null, null, null)
 }
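
A note on the shutdown path above: the dispatcher no longer owns a thread pool,
so awaitTermination() now blocks on a CountDownLatch that stop() counts down once
every message loop has been stopped. A sketch of just that pattern (standalone,
with hypothetical names, not additional Spark code):

    import java.util.concurrent.CountDownLatch

    class ShutdownGate {
      private val latch = new CountDownLatch(1)

      // Called once, after all message loops have been stopped.
      def markStopped(): Unit = latch.countDown()

      // Blocks callers until markStopped() has run.
      def awaitTermination(): Unit = latch.await()
    }
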
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index 44d2622..2ed03f7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -54,9 +54,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteA
 /**
 * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
  */
-private[netty] class Inbox(
-    val endpointRef: NettyRpcEndpointRef,
-    val endpoint: RpcEndpoint)
+private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint)
   extends Logging {
 
   inbox =>  // Give this an alias so we can use it more clearly in closures.
@@ -195,7 +193,7 @@ private[netty] class Inbox(
    * Exposed for testing.
    */
   protected def onDrop(message: InboxMessage): Unit = {
-    logWarning(s"Drop $message because $endpointRef is stopped")
+    logWarning(s"Drop $message because endpoint $endpointName is stopped")
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala
new file mode 100644
index 0000000..c985c72
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent._
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.EXECUTOR_ID
+import org.apache.spark.internal.config.Network._
+import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEndpoint}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A message loop used by [[Dispatcher]] to deliver messages to endpoints.
+ */
+private sealed abstract class MessageLoop(dispatcher: Dispatcher) extends Logging {
+
+  // List of inboxes with pending messages, to be processed by the message loop.
+  private val active = new LinkedBlockingQueue[Inbox]()
+
+  // Message loop task; should be run in all threads of the message loop's pool.
+  protected val receiveLoopRunnable = new Runnable() {
+    override def run(): Unit = receiveLoop()
+  }
+
+  protected val threadpool: ExecutorService
+
+  private var stopped = false
+
+  def post(endpointName: String, message: InboxMessage): Unit
+
+  def unregister(name: String): Unit
+
+  def stop(): Unit = {
+    synchronized {
+      if (!stopped) {
+        setActive(MessageLoop.PoisonPill)
+        threadpool.shutdown()
+        stopped = true
+      }
+    }
+    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+  }
+
+  protected final def setActive(inbox: Inbox): Unit = active.offer(inbox)
+
+  private def receiveLoop(): Unit = {
+    try {
+      while (true) {
+        try {
+          val inbox = active.take()
+          if (inbox == MessageLoop.PoisonPill) {
+            // Put PoisonPill back so that other threads can see it.
+            setActive(MessageLoop.PoisonPill)
+            return
+          }
+          inbox.process(dispatcher)
+        } catch {
+          case NonFatal(e) => logError(e.getMessage, e)
+        }
+      }
+    } catch {
+      case _: InterruptedException => // exit
+      case t: Throwable =>
+        try {
+          // Re-submit a receive task so that message delivery will still work if
+          // UncaughtExceptionHandler decides to not kill JVM.
+          threadpool.execute(receiveLoopRunnable)
+        } finally {
+          throw t
+        }
+    }
+  }
+}
+
+private object MessageLoop {
+  /** A poison inbox that indicates the message loop should stop processing messages. */
+  val PoisonPill = new Inbox(null, null)
+}
+
+/**
+ * A message loop that serves multiple RPC endpoints, using a shared thread pool.
+ */
+private class SharedMessageLoop(
+    conf: SparkConf,
+    dispatcher: Dispatcher,
+    numUsableCores: Int)
+  extends MessageLoop(dispatcher) {
+
+  private val endpoints = new ConcurrentHashMap[String, Inbox]()
+
+  private def getNumOfThreads(conf: SparkConf): Int = {
+    val availableCores =
+      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
+
+    val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
+      .getOrElse(math.max(2, availableCores))
+
+    conf.get(EXECUTOR_ID).map { id =>
+      val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
+      conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
+    }.getOrElse(modNumThreads)
+  }
+
+  /** Thread pool used for dispatching messages. */
+  override protected val threadpool: ThreadPoolExecutor = {
+    val numThreads = getNumOfThreads(conf)
+    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
+    for (i <- 0 until numThreads) {
+      pool.execute(receiveLoopRunnable)
+    }
+    pool
+  }
+
+  override def post(endpointName: String, message: InboxMessage): Unit = {
+    val inbox = endpoints.get(endpointName)
+    inbox.post(message)
+    setActive(inbox)
+  }
+
+  override def unregister(name: String): Unit = {
+    val inbox = endpoints.remove(name)
+    if (inbox != null) {
+      inbox.stop()
+      // Mark active to handle the OnStop message.
+      setActive(inbox)
+    }
+  }
+
+  def register(name: String, endpoint: RpcEndpoint): Unit = {
+    val inbox = new Inbox(name, endpoint)
+    endpoints.put(name, inbox)
+    // Mark active to handle the OnStart message.
+    setActive(inbox)
+  }
+}
+
+/**
+ * A message loop that is dedicated to a single RPC endpoint.
+ */
+private class DedicatedMessageLoop(
+    name: String,
+    endpoint: IsolatedRpcEndpoint,
+    dispatcher: Dispatcher)
+  extends MessageLoop(dispatcher) {
+
+  private val inbox = new Inbox(name, endpoint)
+
+  override protected val threadpool = if (endpoint.threadCount() > 1) {
+    ThreadUtils.newDaemonCachedThreadPool(s"dispatcher-$name", endpoint.threadCount())
+  } else {
+    ThreadUtils.newDaemonSingleThreadExecutor(s"dispatcher-$name")
+  }
+
+  (1 to endpoint.threadCount()).foreach { _ =>
+    threadpool.submit(receiveLoopRunnable)
+  }
+
+  // Mark active to handle the OnStart message.
+  setActive(inbox)
+
+  override def post(endpointName: String, message: InboxMessage): Unit = {
+    require(endpointName == name)
+    inbox.post(message)
+    setActive(inbox)
+  }
+
+  override def unregister(endpointName: String): Unit = synchronized {
+    require(endpointName == name)
+    inbox.stop()
+    // Mark active to handle the OnStop message.
+    setActive(inbox)
+    setActive(MessageLoop.PoisonPill)
+    threadpool.shutdown()
+  }
+}
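
Both loops rely on the MessageLoop.PoisonPill sentinel above: stop() offers the
pill, and each pool thread that takes it puts it back before exiting, so every
thread eventually sees it. A self-contained sketch of that pattern (demo names
are hypothetical, not Spark code):

    import java.util.concurrent.LinkedBlockingQueue

    object PoisonPillDemo {
      // Sentinel object; compared by reference identity (eq).
      private val PoisonPill = new Object

      def main(args: Array[String]): Unit = {
        val queue = new LinkedBlockingQueue[Object]()
        val worker = new Thread(() => {
          var done = false
          while (!done) {
            val item = queue.take()
            if (item eq PoisonPill) {
              queue.offer(PoisonPill) // put it back so other workers exit too
              done = true
            } else {
              println(s"processing $item")
            }
          }
        })
        worker.start()
        queue.offer("message")
        queue.offer(PoisonPill)
        worker.join()
      }
    }
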
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 4958389..6e990d1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -111,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
   private val reviveThread =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread")
 
-  class DriverEndpoint extends ThreadSafeRpcEndpoint with Logging {
+  class DriverEndpoint extends IsolatedRpcEndpoint with Logging {
 
     override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv
 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index faf6f71..02d0e1a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -30,7 +30,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.network.shuffle.ExternalBlockStoreClient
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler._
 import org.apache.spark.storage.BlockManagerMessages._
 import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
@@ -46,7 +46,7 @@ class BlockManagerMasterEndpoint(
     conf: SparkConf,
     listenerBus: LiveListenerBus,
     externalBlockStoreClient: Option[ExternalBlockStoreClient])
-  extends ThreadSafeRpcEndpoint with Logging {
+  extends IsolatedRpcEndpoint with Logging {
 
   // Mapping from block manager id to the block manager's information.
   private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
index f90595a..29e2114 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future}
 
 import org.apache.spark.{MapOutputTracker, SparkEnv}
 import org.apache.spark.internal.Logging
-import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv}
 import org.apache.spark.storage.BlockManagerMessages._
 import org.apache.spark.util.{ThreadUtils, Utils}
 
@@ -34,7 +34,7 @@ class BlockManagerSlaveEndpoint(
     override val rpcEnv: RpcEnv,
     blockManager: BlockManager,
     mapOutputTracker: MapOutputTracker)
-  extends ThreadSafeRpcEndpoint with Logging {
+  extends IsolatedRpcEndpoint with Logging {
 
   private val asyncThreadPool =
     ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100)
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 5929fbf..c10f2c2 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -36,7 +36,6 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.config._
-import org.apache.spark.internal.config.Network
 import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
@@ -954,6 +953,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(endpoint, never()).onDisconnected(any())
     verify(endpoint, never()).onNetworkError(any(), any())
   }
+
+  test("isolated endpoints") {
+    val latch = new CountDownLatch(1)
+    val singleThreadedEnv = createRpcEnv(
+      new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0)
+    try {
+      val blockingEndpoint = singleThreadedEnv.setupEndpoint("blocking", new IsolatedRpcEndpoint {
+        override val rpcEnv: RpcEnv = singleThreadedEnv
+
+        override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+          case m =>
+            latch.await()
+            context.reply(m)
+        }
+      })
+
+      val nonBlockingEndpoint = singleThreadedEnv.setupEndpoint("non-blocking", new RpcEndpoint {
+        override val rpcEnv: RpcEnv = singleThreadedEnv
+
+        override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+          case m => context.reply(m)
+        }
+      })
+
+      val to = new RpcTimeout(5.seconds, "test-timeout")
+      val blockingFuture = blockingEndpoint.ask[String]("hi", to)
+      assert(nonBlockingEndpoint.askSync[String]("hello", to) === "hello")
+      latch.countDown()
+      assert(ThreadUtils.awaitResult(blockingFuture, 5.seconds) === "hi")
+    } finally {
+      latch.countDown()
+      singleThreadedEnv.shutdown()
+    }
+  }
 }
 
 class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
index e553956..c74c728 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
@@ -29,12 +29,9 @@ class InboxSuite extends SparkFunSuite {
 
   test("post") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
-    when(endpointRef.name).thenReturn("hello")
-
     val dispatcher = mock(classOf[Dispatcher])
 
-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     val message = OneWayMessage(null, "hi")
     inbox.post(message)
     inbox.process(dispatcher)
@@ -51,10 +48,9 @@ class InboxSuite extends SparkFunSuite {
 
   test("post: with reply") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
 
-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     val message = RpcMessage(null, "hi", null)
     inbox.post(message)
     inbox.process(dispatcher)
@@ -65,13 +61,10 @@ class InboxSuite extends SparkFunSuite {
 
   test("post: multiple threads") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
-    when(endpointRef.name).thenReturn("hello")
-
     val dispatcher = mock(classOf[Dispatcher])
 
     val numDroppedMessages = new AtomicInteger(0)
-    val inbox = new Inbox(endpointRef, endpoint) {
+    val inbox = new Inbox("name", endpoint) {
       override def onDrop(message: InboxMessage): Unit = {
         numDroppedMessages.incrementAndGet()
       }
@@ -107,12 +100,10 @@ class InboxSuite extends SparkFunSuite {
 
   test("post: Associated") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
-
     val remoteAddress = RpcAddress("localhost", 11111)
 
-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     inbox.post(RemoteProcessConnected(remoteAddress))
     inbox.process(dispatcher)
 
@@ -121,12 +112,11 @@ class InboxSuite extends SparkFunSuite {
 
   test("post: Disassociated") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
 
     val remoteAddress = RpcAddress("localhost", 11111)
 
-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     inbox.post(RemoteProcessDisconnected(remoteAddress))
     inbox.process(dispatcher)
 
@@ -135,13 +125,12 @@ class InboxSuite extends SparkFunSuite {
 
   test("post: AssociationError") {
     val endpoint = new TestRpcEndpoint
-    val endpointRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
 
     val remoteAddress = RpcAddress("localhost", 11111)
     val cause = new RuntimeException("Oops")
 
-    val inbox = new Inbox(endpointRef, endpoint)
+    val inbox = new Inbox("name", endpoint)
     inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
     inbox.process(dispatcher)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
