waitinfuture commented on code in PR #2366:
URL: https://github.com/apache/celeborn/pull/2366#discussion_r1564585403
##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,133 @@ private[celeborn] class Inbox(
private var numActiveThreads = 0
// OnStart should be the first message to process
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
messages.add(OnStart)
+ } finally {
+ inboxLock.unlock()
+ }
+
+ def addMessage(message: InboxMessage): Unit = {
+ messages.add(message)
+ messageCount.incrementAndGet()
+ signalNotFull()
Review Comment:
Is it necessary to call `signalNotFull()` when adding a message? It seems we
should only call it when polling a message.
##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,133 @@ private[celeborn] class Inbox(
private var numActiveThreads = 0
// OnStart should be the first message to process
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
messages.add(OnStart)
+ } finally {
+ inboxLock.unlock()
+ }
+
+ def addMessage(message: InboxMessage): Unit = {
+ messages.add(message)
+ messageCount.incrementAndGet()
+ signalNotFull()
+ logDebug(s"queue length of ${messageCount.get()} ")
+ }
+
+ private def processInternal(dispatcher: Dispatcher, message: InboxMessage):
Unit = {
+ message match {
+ case RpcMessage(_sender, content, context) =>
+ try {
+ endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
+ content,
+ { msg =>
+ throw new CelebornException(s"Unsupported message $message from
${_sender}")
+ })
+ } catch {
+ case e: Throwable =>
+ context.sendFailure(e)
+ // Throw the exception -- this exception will be caught by the
safelyCall function.
+ // The endpoint's onError function will be called.
+ throw e
+ }
+
+ case OneWayMessage(_sender, content) =>
+ endpoint.receive.applyOrElse[Any, Unit](
+ content,
+ { msg =>
+ throw new CelebornException(s"Unsupported message $message from
${_sender}")
+ })
+
+ case OnStart =>
+ endpoint.onStart()
+ if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+ try {
+ inboxLock.lockInterruptibly()
+ if (!stopped) {
+ enableConcurrent = true
+ }
+ } finally {
+ inboxLock.unlock()
+ }
+ }
+
+ case OnStop =>
+ val activeThreads =
+ try {
+ inboxLock.lockInterruptibly()
+ inbox.numActiveThreads
+ } finally {
+ inboxLock.unlock()
+ }
+ assert(
+ activeThreads == 1,
+ s"There should be only a single active thread but found
$activeThreads threads.")
+ dispatcher.removeRpcEndpointRef(endpoint)
+ endpoint.onStop()
+ assert(isEmpty, "OnStop should be the last message")
+
+ case RemoteProcessConnected(remoteAddress) =>
+ endpoint.onConnected(remoteAddress)
+
+ case RemoteProcessDisconnected(remoteAddress) =>
+ endpoint.onDisconnected(remoteAddress)
+
+ case RemoteProcessConnectionError(cause, remoteAddress) =>
+ endpoint.onNetworkError(cause, remoteAddress)
+
+ case other =>
Review Comment:
I think it's unnecessary to match `other` here because all possible
subclasses of `InboxMessage` are matched already.
##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -17,10 +17,14 @@
package org.apache.celeborn.common.rpc.netty
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy
+import scala.collection.mutable
Review Comment:
nit: unused import
##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,133 @@ private[celeborn] class Inbox(
private var numActiveThreads = 0
// OnStart should be the first message to process
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
messages.add(OnStart)
+ } finally {
Review Comment:
It seems we need to increment `messageCount` after adding `OnStart`, because it
will be decremented when the message is processed.
##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -173,38 +233,59 @@ private[celeborn] class Inbox(
message = messages.poll()
if (message == null) {
numActiveThreads -= 1
+ signalNotFull()
Review Comment:
ditto
##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,133 @@ private[celeborn] class Inbox(
private var numActiveThreads = 0
// OnStart should be the first message to process
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
messages.add(OnStart)
+ } finally {
+ inboxLock.unlock()
+ }
+
+ def addMessage(message: InboxMessage): Unit = {
+ messages.add(message)
+ messageCount.incrementAndGet()
+ signalNotFull()
+ logDebug(s"queue length of ${messageCount.get()} ")
+ }
+
+ private def processInternal(dispatcher: Dispatcher, message: InboxMessage):
Unit = {
+ message match {
+ case RpcMessage(_sender, content, context) =>
+ try {
+ endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
+ content,
+ { msg =>
+ throw new CelebornException(s"Unsupported message $message from
${_sender}")
+ })
+ } catch {
+ case e: Throwable =>
+ context.sendFailure(e)
+ // Throw the exception -- this exception will be caught by the
safelyCall function.
+ // The endpoint's onError function will be called.
+ throw e
+ }
+
+ case OneWayMessage(_sender, content) =>
+ endpoint.receive.applyOrElse[Any, Unit](
+ content,
+ { msg =>
+ throw new CelebornException(s"Unsupported message $message from
${_sender}")
+ })
+
+ case OnStart =>
+ endpoint.onStart()
+ if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+ try {
+ inboxLock.lockInterruptibly()
+ if (!stopped) {
+ enableConcurrent = true
+ }
+ } finally {
+ inboxLock.unlock()
+ }
+ }
+
+ case OnStop =>
+ val activeThreads =
+ try {
+ inboxLock.lockInterruptibly()
+ inbox.numActiveThreads
+ } finally {
+ inboxLock.unlock()
+ }
+ assert(
+ activeThreads == 1,
+ s"There should be only a single active thread but found
$activeThreads threads.")
+ dispatcher.removeRpcEndpointRef(endpoint)
+ endpoint.onStop()
+ assert(isEmpty, "OnStop should be the last message")
+
+ case RemoteProcessConnected(remoteAddress) =>
+ endpoint.onConnected(remoteAddress)
+
+ case RemoteProcessDisconnected(remoteAddress) =>
+ endpoint.onDisconnected(remoteAddress)
+
+ case RemoteProcessConnectionError(cause, remoteAddress) =>
+ endpoint.onNetworkError(cause, remoteAddress)
+
+ case other =>
+ throw new IllegalStateException(s"unsupported message $other")
+ }
+ }
+
+ private[netty] def waitOnFull(): Unit = {
+ if (capacity > 0 && !stopped) {
+ try {
+ inboxLock.lockInterruptibly()
+ while (messageCount.get() == capacity) {
+ isFull.await()
+ }
+ } finally {
+ inboxLock.unlock()
+ }
+ }
+ }
+
+ private def signalNotFull(): Unit = {
+ // when this is called we assume putLock already being called
+ require(inboxLock.isHeldByCurrentThread, "cannot call signalNotFull
without holding lock")
+ if (capacity > 0 && messageCount.get() < capacity) {
+ isFull.signal()
+ }
}
- /**
- * Process stored messages.
- */
def process(dispatcher: Dispatcher): Unit = {
var message: InboxMessage = null
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
if (!enableConcurrent && numActiveThreads != 0) {
return
}
message = messages.poll()
if (message != null) {
numActiveThreads += 1
+ messageCount.decrementAndGet()
+ signalNotFull()
} else {
+ signalNotFull()
Review Comment:
It seems we should not signal here, because `null` means the queue is empty.
##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -173,38 +233,59 @@ private[celeborn] class Inbox(
message = messages.poll()
if (message == null) {
numActiveThreads -= 1
+ signalNotFull()
return
+ } else {
+ messageCount.decrementAndGet()
+ signalNotFull()
}
+ } finally {
+ inboxLock.unlock()
}
}
}
- def post(message: InboxMessage): Unit = inbox.synchronized {
- if (stopped) {
- // We already put "OnStop" into "messages", so we should drop further
messages
- onDrop(message)
- } else {
- messages.add(message)
- false
+ def post(message: InboxMessage): Unit = {
+ try {
+ inboxLock.lockInterruptibly()
+ if (stopped) {
+ // We already put "OnStop" into "messages", so we should drop further
messages
+ onDrop(message)
+ } else {
+ addMessage(message)
+ }
+ signalNotFull()
Review Comment:
ditto
##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,132 @@ private[celeborn] class Inbox(
private var numActiveThreads = 0
// OnStart should be the first message to process
- inbox.synchronized {
+ try {
+ inboxLock.lockInterruptibly()
messages.add(OnStart)
+ } finally {
+ inboxLock.unlock()
+ }
+
+ def addMessage(message: InboxMessage): Unit = {
+ messages.add(message)
+ messageCount.incrementAndGet()
+ signalNotFull()
+ logDebug(s"queue length of ${messageCount.get()} ")
+ }
+
+ private def processInternal(dispatcher: Dispatcher, message: InboxMessage):
Unit = {
+ message match {
+ case RpcMessage(_sender, content, context) =>
+ try {
+ endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
+ content,
+ { msg =>
+ throw new CelebornException(s"Unsupported message $message from
${_sender}")
+ })
+ } catch {
+ case e: Throwable =>
+ context.sendFailure(e)
+ // Throw the exception -- this exception will be caught by the
safelyCall function.
+ // The endpoint's onError function will be called.
+ throw e
+ }
+
+ case OneWayMessage(_sender, content) =>
+ endpoint.receive.applyOrElse[Any, Unit](
+ content,
+ { msg =>
+ throw new CelebornException(s"Unsupported message $message from
${_sender}")
+ })
+
+ case OnStart =>
+ endpoint.onStart()
+ if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+ try {
+ inboxLock.lockInterruptibly()
+ if (!stopped) {
+ enableConcurrent = true
+ }
+ } finally {
+ inboxLock.unlock()
+ }
+ }
+
+ case OnStop =>
+ val activeThreads =
+ try {
+ inboxLock.lockInterruptibly()
+ inbox.numActiveThreads
+ } finally {
+ inboxLock.unlock()
+ }
+ assert(
+ activeThreads == 1,
+ s"There should be only a single active thread but found
$activeThreads threads.")
+ dispatcher.removeRpcEndpointRef(endpoint)
+ endpoint.onStop()
+ assert(isEmpty, "OnStop should be the last message")
+
+ case RemoteProcessConnected(remoteAddress) =>
+ endpoint.onConnected(remoteAddress)
+
+ case RemoteProcessDisconnected(remoteAddress) =>
+ endpoint.onDisconnected(remoteAddress)
+
+ case RemoteProcessConnectionError(cause, remoteAddress) =>
+ endpoint.onNetworkError(cause, remoteAddress)
+
+ case other =>
+ throw new IllegalStateException(s"unsupported message $other")
+ }
+ }
+
+ private[netty] def waitOnFull(): Unit = {
+ if (capacity > 0 && !stopped) {
+ try {
+ inboxLock.lockInterruptibly()
+ while (messageCount.get() == capacity) {
+ isFull.await()
Review Comment:
Thanks for the explanation :) I think we should use `messageCount.get() >=
capacity` here. For example, if the capacity is 100 and the size is 99, two
threads may concurrently call `waitOnFull` and both return immediately; then both
of them will add a message, after which the count exceeds 100, and
`messageCount.get() == capacity` will always return false.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]