This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 121395f0f [CELEBORN-1314] add capacity-bounded inbox for rpc endpoint
121395f0f is described below

commit 121395f0f53656b21b22517e69fabc79cc4ef2ca
Author: CodingCat <[email protected]>
AuthorDate: Tue Apr 16 10:56:32 2024 +0800

    [CELEBORN-1314] add capacity-bounded inbox for rpc endpoint
    
    ### What changes were proposed in this pull request?
    
    we found a lot of driver OOM issues when dealing with Spark applications 
with super large shuffles,
    
    with the heap dump, we found the inbox of RPC endpoints accumulated tons of 
change partition location messages... even though we have increased splitThreshold 
to 10G, many jobs still have this issue (continuing to increase this value would 
increase the risk of disk overuse on the workers)
    
    This PR implements a capacity-bounded inbox which is based on a 
blocking queue with a configured capacity; we found it effectively 
resolves the problem for us
    
    ### Why are the changes needed?
    
    the following screenshots show the main memory consumer on the driver side
    
    <img width="661" alt="image" 
src="https://github.com/apache/incubator-celeborn/assets/678008/d63196cc-6c3c-4b32-a9db-9871e7cb5fd8">
    <img width="723" alt="image" 
src="https://github.com/apache/incubator-celeborn/assets/678008/64a506c4-03ea-4932-98ba-f8f4923daa6e">
    
    ### Does this PR introduce _any_ user-facing change?
    
    no, but it adds two more configurations
    
    ### How was this patch tested?
    
    integration tests and unit tests
    
    screenshot showing the application driver memory usage with the patch (blue 
line)
    
    <img width="766" alt="image" 
src="https://github.com/apache/incubator-celeborn/assets/678008/86ecaba8-c164-4aef-ad83-cee03238e5da">
    
    screenshot showing the application driver memory usage without the patch 
(brown line)
    
    <img width="799" alt="image" 
src="https://github.com/apache/incubator-celeborn/assets/678008/a012e0ba-0292-4d25-a7b9-252bdc3cb8cb">
    
    Closes #2366 from CodingCat/memory_bounded_driver.
    
    Authored-by: CodingCat <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../org/apache/celeborn/common/CelebornConf.scala  |  14 ++
 .../celeborn/common/rpc/netty/Dispatcher.scala     |  10 +-
 .../apache/celeborn/common/rpc/netty/Inbox.scala   | 249 ++++++++++++++-------
 .../celeborn/common/rpc/netty/InboxSuite.scala     |  54 +++--
 docs/configuration/network.md                      |   1 +
 5 files changed, 220 insertions(+), 108 deletions(-)

diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index d6acd09a4..bb6b96e93 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -400,6 +400,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
     new RpcTimeout(get(RPC_LOOKUP_TIMEOUT).milli, RPC_LOOKUP_TIMEOUT.key)
   def rpcAskTimeout: RpcTimeout =
     new RpcTimeout(get(RPC_ASK_TIMEOUT).milli, RPC_ASK_TIMEOUT.key)
+  def rpcInMemoryBoundedInboxCapacity(): Int = {
+    get(RPC_INBOX_CAPACITY)
+  }
   def rpcDispatcherNumThreads(availableCores: Int): Int = {
     val num = get(RPC_DISPATCHER_THREADS)
     if (num != 0) num else availableCores
@@ -1592,6 +1595,17 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(0)
 
+  val RPC_INBOX_CAPACITY: ConfigEntry[Int] =
+    buildConf("celeborn.rpc.inbox.capacity")
+      .categories("network")
+      .doc("Specifies size of the in memory bounded capacity.")
+      .version("0.5.0")
+      .intConf
+      .checkValue(
+        v => v >= 0,
+        "the capacity of inbox must be no less than 0, 0 means no limitation")
+      .createWithDefault(0)
+
   val RPC_ROLE_DISPATHER_THREADS: ConfigEntry[Int] =
     buildConf("celeborn.<role>.rpc.dispatcher.threads")
       .categories("network")
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
index 391b64186..b8a93bda0 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
@@ -39,7 +39,8 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv) 
extends Logging {
       val name: String,
       val endpoint: RpcEndpoint,
       val ref: NettyRpcEndpointRef) {
-    val inbox = new Inbox(ref, endpoint)
+    val celebornConf = nettyEnv.celebornConf
+    val inbox = new Inbox(ref, endpoint, celebornConf)
   }
 
   private val endpoints: ConcurrentMap[String, EndpointData] =
@@ -157,7 +158,14 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv) 
extends Logging {
       endpointName: String,
       message: InboxMessage,
       callbackIfStopped: Exception => Unit): Unit = {
+    val data = synchronized {
+      endpoints.get(endpointName)
+    }
+    if (data != null) {
+      data.inbox.waitOnFull()
+    }
     val error = synchronized {
+      // double check
       val data = endpoints.get(endpointName)
       if (stopped) {
         Some(new RpcEnvStoppedException())
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
index 09cdd08e2..750252456 100644
--- a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
@@ -17,10 +17,13 @@
 
 package org.apache.celeborn.common.rpc.netty
 
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.locks.ReentrantLock
 import javax.annotation.concurrent.GuardedBy
 
 import scala.util.control.NonFatal
 
+import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.CelebornException
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.rpc.{RpcAddress, RpcEndpoint, 
ThreadSafeRpcEndpoint}
@@ -64,14 +67,21 @@ private[celeborn] case class RemoteProcessConnectionError(
  */
 private[celeborn] class Inbox(
     val endpointRef: NettyRpcEndpointRef,
-    val endpoint: RpcEndpoint)
-  extends Logging {
+    val endpoint: RpcEndpoint,
+    val conf: CelebornConf) extends Logging {
 
   inbox => // Give this an alias so we can use it more clearly in closures.
 
+  private[netty] val capacity = conf.get(CelebornConf.RPC_INBOX_CAPACITY)
+
+  private[netty] val inboxLock = new ReentrantLock()
+  private[netty] val isFull = inboxLock.newCondition()
+
   @GuardedBy("this")
   protected val messages = new java.util.LinkedList[InboxMessage]()
 
+  private val messageCount = new AtomicLong(0)
+
   /** True if the inbox (and its associated endpoint) is stopped. */
   @GuardedBy("this")
   private var stopped = false
@@ -85,84 +95,130 @@ private[celeborn] class Inbox(
   private var numActiveThreads = 0
 
   // OnStart should be the first message to process
-  inbox.synchronized {
+  try {
+    inboxLock.lockInterruptibly()
     messages.add(OnStart)
+    messageCount.incrementAndGet()
+  } finally {
+    inboxLock.unlock()
+  }
+
+  def addMessage(message: InboxMessage): Unit = {
+    messages.add(message)
+    messageCount.incrementAndGet()
+    signalNotFull()
+    logDebug(s"queue length of ${messageCount.get()} ")
+  }
+
+  private def processInternal(dispatcher: Dispatcher, message: InboxMessage): 
Unit = {
+    message match {
+      case RpcMessage(_sender, content, context) =>
+        try {
+          endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
+            content,
+            { msg =>
+              throw new CelebornException(s"Unsupported message $message from 
${_sender}")
+            })
+        } catch {
+          case e: Throwable =>
+            context.sendFailure(e)
+            // Throw the exception -- this exception will be caught by the 
safelyCall function.
+            // The endpoint's onError function will be called.
+            throw e
+        }
+
+      case OneWayMessage(_sender, content) =>
+        endpoint.receive.applyOrElse[Any, Unit](
+          content,
+          { msg =>
+            throw new CelebornException(s"Unsupported message $message from 
${_sender}")
+          })
+
+      case OnStart =>
+        endpoint.onStart()
+        if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+          try {
+            inboxLock.lockInterruptibly()
+            if (!stopped) {
+              enableConcurrent = true
+            }
+          } finally {
+            inboxLock.unlock()
+          }
+        }
+
+      case OnStop =>
+        val activeThreads =
+          try {
+            inboxLock.lockInterruptibly()
+            inbox.numActiveThreads
+          } finally {
+            inboxLock.unlock()
+          }
+        assert(
+          activeThreads == 1,
+          s"There should be only a single active thread but found 
$activeThreads threads.")
+        dispatcher.removeRpcEndpointRef(endpoint)
+        endpoint.onStop()
+        assert(isEmpty, "OnStop should be the last message")
+
+      case RemoteProcessConnected(remoteAddress) =>
+        endpoint.onConnected(remoteAddress)
+
+      case RemoteProcessDisconnected(remoteAddress) =>
+        endpoint.onDisconnected(remoteAddress)
+
+      case RemoteProcessConnectionError(cause, remoteAddress) =>
+        endpoint.onNetworkError(cause, remoteAddress)
+    }
+  }
+
+  private[netty] def waitOnFull(): Unit = {
+    if (capacity > 0 && !stopped) {
+      try {
+        inboxLock.lockInterruptibly()
+        while (messageCount.get() >= capacity) {
+          isFull.await()
+        }
+      } finally {
+        inboxLock.unlock()
+      }
+    }
+  }
+
+  private def signalNotFull(): Unit = {
+    // when this is called we assume putLock already being called
+    require(inboxLock.isHeldByCurrentThread, "cannot call signalNotFull 
without holding lock")
+    if (capacity > 0 && messageCount.get() < capacity) {
+      isFull.signal()
+    }
   }
 
-  /**
-   * Process stored messages.
-   */
   def process(dispatcher: Dispatcher): Unit = {
     var message: InboxMessage = null
-    inbox.synchronized {
+    try {
+      inboxLock.lockInterruptibly()
       if (!enableConcurrent && numActiveThreads != 0) {
         return
       }
       message = messages.poll()
       if (message != null) {
         numActiveThreads += 1
+        messageCount.decrementAndGet()
+        signalNotFull()
       } else {
         return
       }
+    } finally {
+      inboxLock.unlock()
     }
+
     while (true) {
       safelyCall(endpoint, endpointRef.name) {
-        message match {
-          case RpcMessage(_sender, content, context) =>
-            try {
-              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
-                content,
-                { msg =>
-                  throw new CelebornException(s"Unsupported message $message 
from ${_sender}")
-                })
-            } catch {
-              case e: Throwable =>
-                context.sendFailure(e)
-                // Throw the exception -- this exception will be caught by the 
safelyCall function.
-                // The endpoint's onError function will be called.
-                throw e
-            }
-
-          case OneWayMessage(_sender, content) =>
-            endpoint.receive.applyOrElse[Any, Unit](
-              content,
-              { msg =>
-                throw new CelebornException(s"Unsupported message $message 
from ${_sender}")
-              })
-
-          case OnStart =>
-            endpoint.onStart()
-            if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
-              inbox.synchronized {
-                if (!stopped) {
-                  enableConcurrent = true
-                }
-              }
-            }
-
-          case OnStop =>
-            val activeThreads = inbox.synchronized {
-              inbox.numActiveThreads
-            }
-            assert(
-              activeThreads == 1,
-              s"There should be only a single active thread but found 
$activeThreads threads.")
-            dispatcher.removeRpcEndpointRef(endpoint)
-            endpoint.onStop()
-            assert(isEmpty, "OnStop should be the last message")
-
-          case RemoteProcessConnected(remoteAddress) =>
-            endpoint.onConnected(remoteAddress)
-
-          case RemoteProcessDisconnected(remoteAddress) =>
-            endpoint.onDisconnected(remoteAddress)
-
-          case RemoteProcessConnectionError(cause, remoteAddress) =>
-            endpoint.onNetworkError(cause, remoteAddress)
-        }
+        processInternal(dispatcher, message)
       }
-
-      inbox.synchronized {
+      try {
+        inboxLock.lockInterruptibly()
         // "enableConcurrent" will be set to false after `onStop` is called, 
so we should check it
         // every time.
         if (!enableConcurrent && numActiveThreads != 1) {
@@ -174,37 +230,56 @@ private[celeborn] class Inbox(
         if (message == null) {
           numActiveThreads -= 1
           return
+        } else {
+          messageCount.decrementAndGet()
+          signalNotFull()
         }
+      } finally {
+        inboxLock.unlock()
       }
     }
   }
 
-  def post(message: InboxMessage): Unit = inbox.synchronized {
-    if (stopped) {
-      // We already put "OnStop" into "messages", so we should drop further 
messages
-      onDrop(message)
-    } else {
-      messages.add(message)
-      false
+  def post(message: InboxMessage): Unit = {
+    try {
+      inboxLock.lockInterruptibly()
+      if (stopped) {
+        // We already put "OnStop" into "messages", so we should drop further 
messages
+        onDrop(message)
+      } else {
+        addMessage(message)
+      }
+    } finally {
+      inboxLock.unlock()
     }
   }
 
-  def stop(): Unit = inbox.synchronized {
-    // The following codes should be in `synchronized` so that we can make 
sure "OnStop" is the last
-    // message
-    if (!stopped) {
-      // We should disable concurrent here. Then when RpcEndpoint.onStop is 
called, it's the only
-      // thread that is processing messages. So `RpcEndpoint.onStop` can 
release its resources
-      // safely.
-      enableConcurrent = false
-      stopped = true
-      messages.add(OnStop)
-      // Note: The concurrent events in messages will be processed one by one.
+  def stop(): Unit = {
+    try {
+      inboxLock.lockInterruptibly()
+      // The following codes should be in `synchronized` so that we can make 
sure "OnStop" is the last
+      // message
+      if (!stopped) {
+        // We should disable concurrent here. Then when RpcEndpoint.onStop is 
called, it's the only
+        // thread that is processing messages. So `RpcEndpoint.onStop` can 
release its resources
+        // safely.
+        enableConcurrent = false
+        stopped = true
+        addMessage(OnStop)
+        // Note: The concurrent events in messages will be processed one by 
one.
+      }
+    } finally {
+      inboxLock.unlock()
     }
   }
 
-  def isEmpty: Boolean = inbox.synchronized {
-    messages.isEmpty
+  def isEmpty: Boolean = {
+    try {
+      inboxLock.lockInterruptibly()
+      messages.isEmpty
+    } finally {
+      inboxLock.unlock()
+    }
   }
 
   /**
@@ -222,10 +297,13 @@ private[celeborn] class Inbox(
       endpoint: RpcEndpoint,
       endpointRefName: String)(action: => Unit): Unit = {
     def dealWithFatalError(fatal: Throwable): Unit = {
-      inbox.synchronized {
+      try {
+        inboxLock.lockInterruptibly()
         assert(numActiveThreads > 0, "The number of active threads should be 
positive.")
         // Should reduce the number of active threads before throw the error.
         numActiveThreads -= 1
+      } finally {
+        inboxLock.unlock()
       }
       logError(
         s"An error happened while processing message in the inbox for 
$endpointRefName",
@@ -254,8 +332,11 @@ private[celeborn] class Inbox(
 
   // exposed only for testing
   def getNumActiveThreads: Int = {
-    inbox.synchronized {
+    try {
+      inboxLock.lockInterruptibly()
       inbox.numActiveThreads
+    } finally {
+      inboxLock.unlock()
     }
   }
 }
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
index a8bc826dd..ab86a57e8 100644
--- 
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
+++ 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
@@ -21,18 +21,40 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.mockito.Mockito._
+import org.scalatest.BeforeAndAfter
 
 import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.rpc.{RpcAddress, TestRpcEndpoint}
 
-class InboxSuite extends CelebornFunSuite {
+class InboxSuite extends CelebornFunSuite with BeforeAndAfter {
 
-  test("post") {
-    val endpoint = new TestRpcEndpoint
+  private var inbox: Inbox = _
+  private var endpoint: TestRpcEndpoint = _
+
+  def initInbox[T](
+      testRpcEndpoint: TestRpcEndpoint,
+      onDropOverride: Option[InboxMessage => T]): Inbox = {
     val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+    if (onDropOverride.isEmpty) {
+      new Inbox(rpcEnvRef, testRpcEndpoint, new CelebornConf())
+    } else {
+      new Inbox(rpcEnvRef, testRpcEndpoint, new CelebornConf()) {
+        override protected def onDrop(message: InboxMessage): Unit = {
+          onDropOverride.get(message)
+        }
+      }
+    }
+  }
+
+  before {
+    endpoint = new TestRpcEndpoint
+    inbox = initInbox(endpoint, None)
+  }
+
+  test("post") {
     val dispatcher = mock(classOf[Dispatcher])
 
-    val inbox = new Inbox(rpcEnvRef, endpoint)
     val message = OneWayMessage(null, "hi")
     inbox.post(message)
     inbox.process(dispatcher)
@@ -48,11 +70,8 @@ class InboxSuite extends CelebornFunSuite {
   }
 
   test("post: with reply") {
-    val endpoint = new TestRpcEndpoint
-    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
 
-    val inbox = new Inbox(rpcEnvRef, endpoint)
     val message = RpcMessage(null, "hi", null)
     inbox.post(message)
     inbox.process(dispatcher)
@@ -62,16 +81,15 @@ class InboxSuite extends CelebornFunSuite {
   }
 
   test("post: multiple threads") {
-    val endpoint = new TestRpcEndpoint
     val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
 
     val numDroppedMessages = new AtomicInteger(0)
-    val inbox = new Inbox(rpcEnvRef, endpoint) {
-      override def onDrop(message: InboxMessage): Unit = {
-        numDroppedMessages.incrementAndGet()
-      }
+
+    val overrideOnDrop = (msg: InboxMessage) => {
+      numDroppedMessages.incrementAndGet()
     }
+    val inbox = initInbox(endpoint, Some(overrideOnDrop))
 
     val exitLatch = new CountDownLatch(10)
 
@@ -102,12 +120,9 @@ class InboxSuite extends CelebornFunSuite {
   }
 
   test("post: Associated") {
-    val endpoint = new TestRpcEndpoint
-    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
     val remoteAddress = RpcAddress("localhost", 11111)
 
-    val inbox = new Inbox(rpcEnvRef, endpoint)
     inbox.post(RemoteProcessConnected(remoteAddress))
     inbox.process(dispatcher)
 
@@ -115,13 +130,10 @@ class InboxSuite extends CelebornFunSuite {
   }
 
   test("post: Disassociated") {
-    val endpoint = new TestRpcEndpoint
-    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
 
     val remoteAddress = RpcAddress("localhost", 11111)
 
-    val inbox = new Inbox(rpcEnvRef, endpoint)
     inbox.post(RemoteProcessDisconnected(remoteAddress))
     inbox.process(dispatcher)
 
@@ -129,14 +141,11 @@ class InboxSuite extends CelebornFunSuite {
   }
 
   test("post: AssociationError") {
-    val endpoint = new TestRpcEndpoint
-    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
 
     val remoteAddress = RpcAddress("localhost", 11111)
     val cause = new RuntimeException("Oops")
 
-    val inbox = new Inbox(rpcEnvRef, endpoint)
     inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
     inbox.process(dispatcher)
 
@@ -146,9 +155,8 @@ class InboxSuite extends CelebornFunSuite {
   test("should reduce the number of active threads when fatal error happens") {
     val endpoint = mock(classOf[TestRpcEndpoint])
     when(endpoint.receive).thenThrow(new OutOfMemoryError())
-    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
     val dispatcher = mock(classOf[Dispatcher])
-    val inbox = new Inbox(rpcEnvRef, endpoint)
+    val inbox = initInbox(endpoint, None)
     inbox.post(OneWayMessage(null, "hi"))
     intercept[OutOfMemoryError] {
       inbox.process(dispatcher)
diff --git a/docs/configuration/network.md b/docs/configuration/network.md
index a295d5353..e0e3e8e11 100644
--- a/docs/configuration/network.md
+++ b/docs/configuration/network.md
@@ -50,6 +50,7 @@ license: |
 | celeborn.rpc.askTimeout | 60s | false | Timeout for RPC ask operations. It's 
recommended to set at least `240s` when `HDFS` is enabled in 
`celeborn.storage.activeTypes` | 0.2.0 |  | 
 | celeborn.rpc.connect.threads | 64 | false |  | 0.2.0 |  | 
 | celeborn.rpc.dispatcher.threads | 0 | false | Threads number of message 
dispatcher event loop. Default to 0, which is availableCore. | 0.3.0 | 
celeborn.rpc.dispatcher.numThreads | 
+| celeborn.rpc.inbox.capacity | 0 | false | Specifies size of the in memory 
bounded capacity. | 0.5.0 |  | 
 | celeborn.rpc.io.threads | &lt;undefined&gt; | false | Netty IO thread number 
of NettyRpcEnv to handle RPC request. The default threads number is the number 
of runtime available processors. | 0.2.0 |  | 
 | celeborn.rpc.lookupTimeout | 30s | false | Timeout for RPC lookup 
operations. | 0.2.0 |  | 
 | celeborn.shuffle.io.maxChunksBeingTransferred | &lt;undefined&gt; | false | 
The max number of chunks allowed to be transferred at the same time on shuffle 
service. Note that new incoming connections will be closed when the max number 
is hit. The client will retry according to the shuffle retry configs (see 
`celeborn.<module>.io.maxRetries` and `celeborn.<module>.io.retryWait`), if 
those limits are reached the task will fail with fetch failure. | 0.2.0 |  | 

Reply via email to