This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 121395f0f [CELEBORN-1314] add capacity-bounded inbox for rpc endpoint
121395f0f is described below
commit 121395f0f53656b21b22517e69fabc79cc4ef2ca
Author: CodingCat <[email protected]>
AuthorDate: Tue Apr 16 10:56:32 2024 +0800
[CELEBORN-1314] add capacity-bounded inbox for rpc endpoint
### What changes were proposed in this pull request?
we found a lot of driver OOM issues when dealing with Spark applications
with super large shuffles,
with the heap dump, we found the inbox of rpc endpoints accumulated tons of
change partition location messages... even though we have increased splitThreshold to
10G, many jobs still have this issue (continuing to increase this value will increase
the risk of disk overusage on workers)
This PR implements a capacity-bounded inbox which is based on a
LinkedBlockingQueue with a configured capacity; we found it effectively
resolves the problem for us
### Why are the changes needed?
the following screenshots show the main memory consumer on the Driver side
<img width="661" alt="image"
src="https://github.com/apache/incubator-celeborn/assets/678008/d63196cc-6c3c-4b32-a9db-9871e7cb5fd8">
<img width="723" alt="image"
src="https://github.com/apache/incubator-celeborn/assets/678008/64a506c4-03ea-4932-98ba-f8f4923daa6e">
### Does this PR introduce _any_ user-facing change?
no, but it adds two more configurations
### How was this patch tested?
integration tests and unit tests
screenshot showing the application driver memory usage with the patch (blue
line)
<img width="766" alt="image"
src="https://github.com/apache/incubator-celeborn/assets/678008/86ecaba8-c164-4aef-ad83-cee03238e5da">
screenshot showing the application driver memory usage without the patch (brown
line)
<img width="799" alt="image"
src="https://github.com/apache/incubator-celeborn/assets/678008/a012e0ba-0292-4d25-a7b9-252bdc3cb8cb">
Closes #2366 from CodingCat/memory_bounded_driver.
Authored-by: CodingCat <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
---
.../org/apache/celeborn/common/CelebornConf.scala | 14 ++
.../celeborn/common/rpc/netty/Dispatcher.scala | 10 +-
.../apache/celeborn/common/rpc/netty/Inbox.scala | 249 ++++++++++++++-------
.../celeborn/common/rpc/netty/InboxSuite.scala | 54 +++--
docs/configuration/network.md | 1 +
5 files changed, 220 insertions(+), 108 deletions(-)
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index d6acd09a4..bb6b96e93 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -400,6 +400,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
new RpcTimeout(get(RPC_LOOKUP_TIMEOUT).milli, RPC_LOOKUP_TIMEOUT.key)
def rpcAskTimeout: RpcTimeout =
new RpcTimeout(get(RPC_ASK_TIMEOUT).milli, RPC_ASK_TIMEOUT.key)
+ def rpcInMemoryBoundedInboxCapacity(): Int = {
+ get(RPC_INBOX_CAPACITY)
+ }
def rpcDispatcherNumThreads(availableCores: Int): Int = {
val num = get(RPC_DISPATCHER_THREADS)
if (num != 0) num else availableCores
@@ -1592,6 +1595,17 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(0)
+ val RPC_INBOX_CAPACITY: ConfigEntry[Int] =
+ buildConf("celeborn.rpc.inbox.capacity")
+ .categories("network")
+ .doc("Specifies size of the in memory bounded capacity.")
+ .version("0.5.0")
+ .intConf
+ .checkValue(
+ v => v >= 0,
+ "the capacity of inbox must be no less than 0, 0 means no limitation")
+ .createWithDefault(0)
+
val RPC_ROLE_DISPATHER_THREADS: ConfigEntry[Int] =
buildConf("celeborn.<role>.rpc.dispatcher.threads")
.categories("network")
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
index 391b64186..b8a93bda0 100644
---
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
@@ -39,7 +39,8 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv)
extends Logging {
val name: String,
val endpoint: RpcEndpoint,
val ref: NettyRpcEndpointRef) {
- val inbox = new Inbox(ref, endpoint)
+ val celebornConf = nettyEnv.celebornConf
+ val inbox = new Inbox(ref, endpoint, celebornConf)
}
private val endpoints: ConcurrentMap[String, EndpointData] =
@@ -157,7 +158,14 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv)
extends Logging {
endpointName: String,
message: InboxMessage,
callbackIfStopped: Exception => Unit): Unit = {
+ val data = synchronized {
+ endpoints.get(endpointName)
+ }
+ if (data != null) {
+ data.inbox.waitOnFull()
+ }
val error = synchronized {
+ // double check
val data = endpoints.get(endpointName)
if (stopped) {
Some(new RpcEnvStoppedException())
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
index 09cdd08e2..750252456 100644
--- a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
@@ -17,10 +17,13 @@
package org.apache.celeborn.common.rpc.netty
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy
import scala.util.control.NonFatal
+import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.CelebornException
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.rpc.{RpcAddress, RpcEndpoint,
ThreadSafeRpcEndpoint}
@@ -64,14 +67,21 @@ private[celeborn] case class RemoteProcessConnectionError(
*/
private[celeborn] class Inbox(
val endpointRef: NettyRpcEndpointRef,
- val endpoint: RpcEndpoint)
- extends Logging {
+ val endpoint: RpcEndpoint,
+ val conf: CelebornConf) extends Logging {
inbox => // Give this an alias so we can use it more clearly in closures.
+ private[netty] val capacity = conf.get(CelebornConf.RPC_INBOX_CAPACITY)
+
+ private[netty] val inboxLock = new ReentrantLock()
+ private[netty] val isFull = inboxLock.newCondition()
+
@GuardedBy("this")
protected val messages = new java.util.LinkedList[InboxMessage]()
+ private val messageCount = new AtomicLong(0)
+
/** True if the inbox (and its associated endpoint) is stopped. */
@GuardedBy("this")
private var stopped = false
@@ -85,84 +95,130 @@ private[celeborn] class Inbox(
private var numActiveThreads = 0
// OnStart should be the first message to process
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
messages.add(OnStart)
+ messageCount.incrementAndGet()
+ } finally {
+ inboxLock.unlock()
+ }
+
+ def addMessage(message: InboxMessage): Unit = {
+ messages.add(message)
+ messageCount.incrementAndGet()
+ signalNotFull()
+ logDebug(s"queue length of ${messageCount.get()} ")
+ }
+
+ private def processInternal(dispatcher: Dispatcher, message: InboxMessage):
Unit = {
+ message match {
+ case RpcMessage(_sender, content, context) =>
+ try {
+ endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
+ content,
+ { msg =>
+ throw new CelebornException(s"Unsupported message $message from
${_sender}")
+ })
+ } catch {
+ case e: Throwable =>
+ context.sendFailure(e)
+ // Throw the exception -- this exception will be caught by the
safelyCall function.
+ // The endpoint's onError function will be called.
+ throw e
+ }
+
+ case OneWayMessage(_sender, content) =>
+ endpoint.receive.applyOrElse[Any, Unit](
+ content,
+ { msg =>
+ throw new CelebornException(s"Unsupported message $message from
${_sender}")
+ })
+
+ case OnStart =>
+ endpoint.onStart()
+ if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+ try {
+ inboxLock.lockInterruptibly()
+ if (!stopped) {
+ enableConcurrent = true
+ }
+ } finally {
+ inboxLock.unlock()
+ }
+ }
+
+ case OnStop =>
+ val activeThreads =
+ try {
+ inboxLock.lockInterruptibly()
+ inbox.numActiveThreads
+ } finally {
+ inboxLock.unlock()
+ }
+ assert(
+ activeThreads == 1,
+ s"There should be only a single active thread but found
$activeThreads threads.")
+ dispatcher.removeRpcEndpointRef(endpoint)
+ endpoint.onStop()
+ assert(isEmpty, "OnStop should be the last message")
+
+ case RemoteProcessConnected(remoteAddress) =>
+ endpoint.onConnected(remoteAddress)
+
+ case RemoteProcessDisconnected(remoteAddress) =>
+ endpoint.onDisconnected(remoteAddress)
+
+ case RemoteProcessConnectionError(cause, remoteAddress) =>
+ endpoint.onNetworkError(cause, remoteAddress)
+ }
+ }
+
+ private[netty] def waitOnFull(): Unit = {
+ if (capacity > 0 && !stopped) {
+ try {
+ inboxLock.lockInterruptibly()
+ while (messageCount.get() >= capacity) {
+ isFull.await()
+ }
+ } finally {
+ inboxLock.unlock()
+ }
+ }
+ }
+
+ private def signalNotFull(): Unit = {
+ // when this is called we assume putLock already being called
+ require(inboxLock.isHeldByCurrentThread, "cannot call signalNotFull
without holding lock")
+ if (capacity > 0 && messageCount.get() < capacity) {
+ isFull.signal()
+ }
}
- /**
- * Process stored messages.
- */
def process(dispatcher: Dispatcher): Unit = {
var message: InboxMessage = null
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
if (!enableConcurrent && numActiveThreads != 0) {
return
}
message = messages.poll()
if (message != null) {
numActiveThreads += 1
+ messageCount.decrementAndGet()
+ signalNotFull()
} else {
return
}
+ } finally {
+ inboxLock.unlock()
}
+
while (true) {
safelyCall(endpoint, endpointRef.name) {
- message match {
- case RpcMessage(_sender, content, context) =>
- try {
- endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
- content,
- { msg =>
- throw new CelebornException(s"Unsupported message $message
from ${_sender}")
- })
- } catch {
- case e: Throwable =>
- context.sendFailure(e)
- // Throw the exception -- this exception will be caught by the
safelyCall function.
- // The endpoint's onError function will be called.
- throw e
- }
-
- case OneWayMessage(_sender, content) =>
- endpoint.receive.applyOrElse[Any, Unit](
- content,
- { msg =>
- throw new CelebornException(s"Unsupported message $message
from ${_sender}")
- })
-
- case OnStart =>
- endpoint.onStart()
- if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
- inbox.synchronized {
- if (!stopped) {
- enableConcurrent = true
- }
- }
- }
-
- case OnStop =>
- val activeThreads = inbox.synchronized {
- inbox.numActiveThreads
- }
- assert(
- activeThreads == 1,
- s"There should be only a single active thread but found
$activeThreads threads.")
- dispatcher.removeRpcEndpointRef(endpoint)
- endpoint.onStop()
- assert(isEmpty, "OnStop should be the last message")
-
- case RemoteProcessConnected(remoteAddress) =>
- endpoint.onConnected(remoteAddress)
-
- case RemoteProcessDisconnected(remoteAddress) =>
- endpoint.onDisconnected(remoteAddress)
-
- case RemoteProcessConnectionError(cause, remoteAddress) =>
- endpoint.onNetworkError(cause, remoteAddress)
- }
+ processInternal(dispatcher, message)
}
-
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
// "enableConcurrent" will be set to false after `onStop` is called,
so we should check it
// every time.
if (!enableConcurrent && numActiveThreads != 1) {
@@ -174,37 +230,56 @@ private[celeborn] class Inbox(
if (message == null) {
numActiveThreads -= 1
return
+ } else {
+ messageCount.decrementAndGet()
+ signalNotFull()
}
+ } finally {
+ inboxLock.unlock()
}
}
}
- def post(message: InboxMessage): Unit = inbox.synchronized {
- if (stopped) {
- // We already put "OnStop" into "messages", so we should drop further
messages
- onDrop(message)
- } else {
- messages.add(message)
- false
+ def post(message: InboxMessage): Unit = {
+ try {
+ inboxLock.lockInterruptibly()
+ if (stopped) {
+ // We already put "OnStop" into "messages", so we should drop further
messages
+ onDrop(message)
+ } else {
+ addMessage(message)
+ }
+ } finally {
+ inboxLock.unlock()
}
}
- def stop(): Unit = inbox.synchronized {
- // The following codes should be in `synchronized` so that we can make
sure "OnStop" is the last
- // message
- if (!stopped) {
- // We should disable concurrent here. Then when RpcEndpoint.onStop is
called, it's the only
- // thread that is processing messages. So `RpcEndpoint.onStop` can
release its resources
- // safely.
- enableConcurrent = false
- stopped = true
- messages.add(OnStop)
- // Note: The concurrent events in messages will be processed one by one.
+ def stop(): Unit = {
+ try {
+ inboxLock.lockInterruptibly()
+ // The following codes should be in `synchronized` so that we can make
sure "OnStop" is the last
+ // message
+ if (!stopped) {
+ // We should disable concurrent here. Then when RpcEndpoint.onStop is
called, it's the only
+ // thread that is processing messages. So `RpcEndpoint.onStop` can
release its resources
+ // safely.
+ enableConcurrent = false
+ stopped = true
+ addMessage(OnStop)
+ // Note: The concurrent events in messages will be processed one by
one.
+ }
+ } finally {
+ inboxLock.unlock()
}
}
- def isEmpty: Boolean = inbox.synchronized {
- messages.isEmpty
+ def isEmpty: Boolean = {
+ try {
+ inboxLock.lockInterruptibly()
+ messages.isEmpty
+ } finally {
+ inboxLock.unlock()
+ }
}
/**
@@ -222,10 +297,13 @@ private[celeborn] class Inbox(
endpoint: RpcEndpoint,
endpointRefName: String)(action: => Unit): Unit = {
def dealWithFatalError(fatal: Throwable): Unit = {
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
assert(numActiveThreads > 0, "The number of active threads should be
positive.")
// Should reduce the number of active threads before throw the error.
numActiveThreads -= 1
+ } finally {
+ inboxLock.unlock()
}
logError(
s"An error happened while processing message in the inbox for
$endpointRefName",
@@ -254,8 +332,11 @@ private[celeborn] class Inbox(
// exposed only for testing
def getNumActiveThreads: Int = {
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
inbox.numActiveThreads
+ } finally {
+ inboxLock.unlock()
}
}
}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
index a8bc826dd..ab86a57e8 100644
---
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
+++
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
@@ -21,18 +21,40 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import org.mockito.Mockito._
+import org.scalatest.BeforeAndAfter
import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.rpc.{RpcAddress, TestRpcEndpoint}
-class InboxSuite extends CelebornFunSuite {
+class InboxSuite extends CelebornFunSuite with BeforeAndAfter {
- test("post") {
- val endpoint = new TestRpcEndpoint
+ private var inbox: Inbox = _
+ private var endpoint: TestRpcEndpoint = _
+
+ def initInbox[T](
+ testRpcEndpoint: TestRpcEndpoint,
+ onDropOverride: Option[InboxMessage => T]): Inbox = {
val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+ if (onDropOverride.isEmpty) {
+ new Inbox(rpcEnvRef, testRpcEndpoint, new CelebornConf())
+ } else {
+ new Inbox(rpcEnvRef, testRpcEndpoint, new CelebornConf()) {
+ override protected def onDrop(message: InboxMessage): Unit = {
+ onDropOverride.get(message)
+ }
+ }
+ }
+ }
+
+ before {
+ endpoint = new TestRpcEndpoint
+ inbox = initInbox(endpoint, None)
+ }
+
+ test("post") {
val dispatcher = mock(classOf[Dispatcher])
- val inbox = new Inbox(rpcEnvRef, endpoint)
val message = OneWayMessage(null, "hi")
inbox.post(message)
inbox.process(dispatcher)
@@ -48,11 +70,8 @@ class InboxSuite extends CelebornFunSuite {
}
test("post: with reply") {
- val endpoint = new TestRpcEndpoint
- val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
- val inbox = new Inbox(rpcEnvRef, endpoint)
val message = RpcMessage(null, "hi", null)
inbox.post(message)
inbox.process(dispatcher)
@@ -62,16 +81,15 @@ class InboxSuite extends CelebornFunSuite {
}
test("post: multiple threads") {
- val endpoint = new TestRpcEndpoint
val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
val numDroppedMessages = new AtomicInteger(0)
- val inbox = new Inbox(rpcEnvRef, endpoint) {
- override def onDrop(message: InboxMessage): Unit = {
- numDroppedMessages.incrementAndGet()
- }
+
+ val overrideOnDrop = (msg: InboxMessage) => {
+ numDroppedMessages.incrementAndGet()
}
+ val inbox = initInbox(endpoint, Some(overrideOnDrop))
val exitLatch = new CountDownLatch(10)
@@ -102,12 +120,9 @@ class InboxSuite extends CelebornFunSuite {
}
test("post: Associated") {
- val endpoint = new TestRpcEndpoint
- val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
val remoteAddress = RpcAddress("localhost", 11111)
- val inbox = new Inbox(rpcEnvRef, endpoint)
inbox.post(RemoteProcessConnected(remoteAddress))
inbox.process(dispatcher)
@@ -115,13 +130,10 @@ class InboxSuite extends CelebornFunSuite {
}
test("post: Disassociated") {
- val endpoint = new TestRpcEndpoint
- val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
val remoteAddress = RpcAddress("localhost", 11111)
- val inbox = new Inbox(rpcEnvRef, endpoint)
inbox.post(RemoteProcessDisconnected(remoteAddress))
inbox.process(dispatcher)
@@ -129,14 +141,11 @@ class InboxSuite extends CelebornFunSuite {
}
test("post: AssociationError") {
- val endpoint = new TestRpcEndpoint
- val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
val remoteAddress = RpcAddress("localhost", 11111)
val cause = new RuntimeException("Oops")
- val inbox = new Inbox(rpcEnvRef, endpoint)
inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
inbox.process(dispatcher)
@@ -146,9 +155,8 @@ class InboxSuite extends CelebornFunSuite {
test("should reduce the number of active threads when fatal error happens") {
val endpoint = mock(classOf[TestRpcEndpoint])
when(endpoint.receive).thenThrow(new OutOfMemoryError())
- val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
val dispatcher = mock(classOf[Dispatcher])
- val inbox = new Inbox(rpcEnvRef, endpoint)
+ val inbox = initInbox(endpoint, None)
inbox.post(OneWayMessage(null, "hi"))
intercept[OutOfMemoryError] {
inbox.process(dispatcher)
diff --git a/docs/configuration/network.md b/docs/configuration/network.md
index a295d5353..e0e3e8e11 100644
--- a/docs/configuration/network.md
+++ b/docs/configuration/network.md
@@ -50,6 +50,7 @@ license: |
| celeborn.rpc.askTimeout | 60s | false | Timeout for RPC ask operations. It's
recommended to set at least `240s` when `HDFS` is enabled in
`celeborn.storage.activeTypes` | 0.2.0 | |
| celeborn.rpc.connect.threads | 64 | false | | 0.2.0 | |
| celeborn.rpc.dispatcher.threads | 0 | false | Threads number of message
dispatcher event loop. Default to 0, which is availableCore. | 0.3.0 |
celeborn.rpc.dispatcher.numThreads |
+| celeborn.rpc.inbox.capacity | 0 | false | Specifies size of the in memory
bounded capacity. | 0.5.0 | |
| celeborn.rpc.io.threads | <undefined> | false | Netty IO thread number
of NettyRpcEnv to handle RPC request. The default threads number is the number
of runtime available processors. | 0.2.0 | |
| celeborn.rpc.lookupTimeout | 30s | false | Timeout for RPC lookup
operations. | 0.2.0 | |
| celeborn.shuffle.io.maxChunksBeingTransferred | <undefined> | false |
The max number of chunks allowed to be transferred at the same time on shuffle
service. Note that new incoming connections will be closed when the max number
is hit. The client will retry according to the shuffle retry configs (see
`celeborn.<module>.io.maxRetries` and `celeborn.<module>.io.retryWait`), if
those limits are reached the task will fail with fetch failure. | 0.2.0 | |