waitinfuture commented on code in PR #2366:
URL: https://github.com/apache/celeborn/pull/2366#discussion_r1564585403


##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,133 @@ private[celeborn] class Inbox(
   private var numActiveThreads = 0
 
   // OnStart should be the first message to process
-  inbox.synchronized {
+  try {
+    inboxLock.lockInterruptibly()
     messages.add(OnStart)
+  } finally {
+    inboxLock.unlock()
+  }
+
+  def addMessage(message: InboxMessage): Unit = {
+    messages.add(message)
+    messageCount.incrementAndGet()
+    signalNotFull()

Review Comment:
   Is it necessary to call `signalNotFull()` when adding a message? It seems we 
should only call it when polling a message.



##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,133 @@ private[celeborn] class Inbox(
   private var numActiveThreads = 0
 
   // OnStart should be the first message to process
-  inbox.synchronized {
+  try {
+    inboxLock.lockInterruptibly()
     messages.add(OnStart)
+  } finally {
+    inboxLock.unlock()
+  }
+
+  def addMessage(message: InboxMessage): Unit = {
+    messages.add(message)
+    messageCount.incrementAndGet()
+    signalNotFull()
+    logDebug(s"queue length of ${messageCount.get()} ")
+  }
+
+  private def processInternal(dispatcher: Dispatcher, message: InboxMessage): 
Unit = {
+    message match {
+      case RpcMessage(_sender, content, context) =>
+        try {
+          endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
+            content,
+            { msg =>
+              throw new CelebornException(s"Unsupported message $message from 
${_sender}")
+            })
+        } catch {
+          case e: Throwable =>
+            context.sendFailure(e)
+            // Throw the exception -- this exception will be caught by the 
safelyCall function.
+            // The endpoint's onError function will be called.
+            throw e
+        }
+
+      case OneWayMessage(_sender, content) =>
+        endpoint.receive.applyOrElse[Any, Unit](
+          content,
+          { msg =>
+            throw new CelebornException(s"Unsupported message $message from 
${_sender}")
+          })
+
+      case OnStart =>
+        endpoint.onStart()
+        if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+          try {
+            inboxLock.lockInterruptibly()
+            if (!stopped) {
+              enableConcurrent = true
+            }
+          } finally {
+            inboxLock.unlock()
+          }
+        }
+
+      case OnStop =>
+        val activeThreads =
+          try {
+            inboxLock.lockInterruptibly()
+            inbox.numActiveThreads
+          } finally {
+            inboxLock.unlock()
+          }
+        assert(
+          activeThreads == 1,
+          s"There should be only a single active thread but found 
$activeThreads threads.")
+        dispatcher.removeRpcEndpointRef(endpoint)
+        endpoint.onStop()
+        assert(isEmpty, "OnStop should be the last message")
+
+      case RemoteProcessConnected(remoteAddress) =>
+        endpoint.onConnected(remoteAddress)
+
+      case RemoteProcessDisconnected(remoteAddress) =>
+        endpoint.onDisconnected(remoteAddress)
+
+      case RemoteProcessConnectionError(cause, remoteAddress) =>
+        endpoint.onNetworkError(cause, remoteAddress)
+
+      case other =>

Review Comment:
   I think it's unnecessary to match `other` here because all possible 
subclasses of `InboxMessage` are matched already.



##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -17,10 +17,14 @@
 
 package org.apache.celeborn.common.rpc.netty
 
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.locks.ReentrantLock
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.mutable

Review Comment:
   nit: unused import



##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,133 @@ private[celeborn] class Inbox(
   private var numActiveThreads = 0
 
   // OnStart should be the first message to process
-  inbox.synchronized {
+  try {
+    inboxLock.lockInterruptibly()
     messages.add(OnStart)
+  } finally {

Review Comment:
   It seems we need to increment `messageCount` after adding `OnStart`, because 
it will be decremented when the message is processed.



##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -173,38 +233,59 @@ private[celeborn] class Inbox(
         message = messages.poll()
         if (message == null) {
           numActiveThreads -= 1
+          signalNotFull()

Review Comment:
   ditto



##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,133 @@ private[celeborn] class Inbox(
   private var numActiveThreads = 0
 
   // OnStart should be the first message to process
-  inbox.synchronized {
+  try {
+    inboxLock.lockInterruptibly()
     messages.add(OnStart)
+  } finally {
+    inboxLock.unlock()
+  }
+
+  def addMessage(message: InboxMessage): Unit = {
+    messages.add(message)
+    messageCount.incrementAndGet()
+    signalNotFull()
+    logDebug(s"queue length of ${messageCount.get()} ")
+  }
+
+  private def processInternal(dispatcher: Dispatcher, message: InboxMessage): 
Unit = {
+    message match {
+      case RpcMessage(_sender, content, context) =>
+        try {
+          endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
+            content,
+            { msg =>
+              throw new CelebornException(s"Unsupported message $message from 
${_sender}")
+            })
+        } catch {
+          case e: Throwable =>
+            context.sendFailure(e)
+            // Throw the exception -- this exception will be caught by the 
safelyCall function.
+            // The endpoint's onError function will be called.
+            throw e
+        }
+
+      case OneWayMessage(_sender, content) =>
+        endpoint.receive.applyOrElse[Any, Unit](
+          content,
+          { msg =>
+            throw new CelebornException(s"Unsupported message $message from 
${_sender}")
+          })
+
+      case OnStart =>
+        endpoint.onStart()
+        if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+          try {
+            inboxLock.lockInterruptibly()
+            if (!stopped) {
+              enableConcurrent = true
+            }
+          } finally {
+            inboxLock.unlock()
+          }
+        }
+
+      case OnStop =>
+        val activeThreads =
+          try {
+            inboxLock.lockInterruptibly()
+            inbox.numActiveThreads
+          } finally {
+            inboxLock.unlock()
+          }
+        assert(
+          activeThreads == 1,
+          s"There should be only a single active thread but found 
$activeThreads threads.")
+        dispatcher.removeRpcEndpointRef(endpoint)
+        endpoint.onStop()
+        assert(isEmpty, "OnStop should be the last message")
+
+      case RemoteProcessConnected(remoteAddress) =>
+        endpoint.onConnected(remoteAddress)
+
+      case RemoteProcessDisconnected(remoteAddress) =>
+        endpoint.onDisconnected(remoteAddress)
+
+      case RemoteProcessConnectionError(cause, remoteAddress) =>
+        endpoint.onNetworkError(cause, remoteAddress)
+
+      case other =>
+        throw new IllegalStateException(s"unsupported message $other")
+    }
+  }
+
+  private[netty] def waitOnFull(): Unit = {
+    if (capacity > 0 && !stopped) {
+      try {
+        inboxLock.lockInterruptibly()
+        while (messageCount.get() == capacity) {
+          isFull.await()
+        }
+      } finally {
+        inboxLock.unlock()
+      }
+    }
+  }
+
+  private def signalNotFull(): Unit = {
+    // when this is called we assume putLock already being called
+    require(inboxLock.isHeldByCurrentThread, "cannot call signalNotFull 
without holding lock")
+    if (capacity > 0 && messageCount.get() < capacity) {
+      isFull.signal()
+    }
   }
 
-  /**
-   * Process stored messages.
-   */
   def process(dispatcher: Dispatcher): Unit = {
     var message: InboxMessage = null
-    inbox.synchronized {
+    try {
+      inboxLock.lockInterruptibly()
       if (!enableConcurrent && numActiveThreads != 0) {
         return
       }
       message = messages.poll()
       if (message != null) {
         numActiveThreads += 1
+        messageCount.decrementAndGet()
+        signalNotFull()
       } else {
+        signalNotFull()

Review Comment:
   It seems we should not signal here, because a `null` result means the queue is empty.



##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -173,38 +233,59 @@ private[celeborn] class Inbox(
         message = messages.poll()
         if (message == null) {
           numActiveThreads -= 1
+          signalNotFull()
           return
+        } else {
+          messageCount.decrementAndGet()
+          signalNotFull()
         }
+      } finally {
+        inboxLock.unlock()
       }
     }
   }
 
-  def post(message: InboxMessage): Unit = inbox.synchronized {
-    if (stopped) {
-      // We already put "OnStop" into "messages", so we should drop further 
messages
-      onDrop(message)
-    } else {
-      messages.add(message)
-      false
+  def post(message: InboxMessage): Unit = {
+    try {
+      inboxLock.lockInterruptibly()
+      if (stopped) {
+        // We already put "OnStop" into "messages", so we should drop further 
messages
+        onDrop(message)
+      } else {
+        addMessage(message)
+      }
+      signalNotFull()

Review Comment:
   ditto



##########
common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala:
##########
@@ -85,84 +96,132 @@ private[celeborn] class Inbox(
   private var numActiveThreads = 0
 
   // OnStart should be the first message to process
-  inbox.synchronized {
+  try {
+    inboxLock.lockInterruptibly()
     messages.add(OnStart)
+  } finally {
+    inboxLock.unlock()
+  }
+
+  def addMessage(message: InboxMessage): Unit = {
+    messages.add(message)
+    messageCount.incrementAndGet()
+    signalNotFull()
+    logDebug(s"queue length of ${messageCount.get()} ")
+  }
+
+  private def processInternal(dispatcher: Dispatcher, message: InboxMessage): 
Unit = {
+    message match {
+      case RpcMessage(_sender, content, context) =>
+        try {
+          endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
+            content,
+            { msg =>
+              throw new CelebornException(s"Unsupported message $message from 
${_sender}")
+            })
+        } catch {
+          case e: Throwable =>
+            context.sendFailure(e)
+            // Throw the exception -- this exception will be caught by the 
safelyCall function.
+            // The endpoint's onError function will be called.
+            throw e
+        }
+
+      case OneWayMessage(_sender, content) =>
+        endpoint.receive.applyOrElse[Any, Unit](
+          content,
+          { msg =>
+            throw new CelebornException(s"Unsupported message $message from 
${_sender}")
+          })
+
+      case OnStart =>
+        endpoint.onStart()
+        if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+          try {
+            inboxLock.lockInterruptibly()
+            if (!stopped) {
+              enableConcurrent = true
+            }
+          } finally {
+            inboxLock.unlock()
+          }
+        }
+
+      case OnStop =>
+        val activeThreads =
+          try {
+            inboxLock.lockInterruptibly()
+            inbox.numActiveThreads
+          } finally {
+            inboxLock.unlock()
+          }
+        assert(
+          activeThreads == 1,
+          s"There should be only a single active thread but found 
$activeThreads threads.")
+        dispatcher.removeRpcEndpointRef(endpoint)
+        endpoint.onStop()
+        assert(isEmpty, "OnStop should be the last message")
+
+      case RemoteProcessConnected(remoteAddress) =>
+        endpoint.onConnected(remoteAddress)
+
+      case RemoteProcessDisconnected(remoteAddress) =>
+        endpoint.onDisconnected(remoteAddress)
+
+      case RemoteProcessConnectionError(cause, remoteAddress) =>
+        endpoint.onNetworkError(cause, remoteAddress)
+
+      case other =>
+        throw new IllegalStateException(s"unsupported message $other")
+    }
+  }
+
+  private[netty] def waitOnFull(): Unit = {
+    if (capacity > 0 && !stopped) {
+      try {
+        inboxLock.lockInterruptibly()
+        while (messageCount.get() == capacity) {
+          isFull.await()

Review Comment:
   Thanks for the explanation :) I think we should use `messageCount.get() >= 
capacity` here. For example, if the capacity is 100 and the size is 99, two 
threads concurrently call `waitOnFull` and both return immediately; then both of 
them add a message, after which the size exceeds 100, and `messageCount.get() 
== capacity` always returns false.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to