This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 788b0c340 [CELEBORN-1135] Added tests for the RpcEnv and related 
classes
788b0c340 is described below

commit 788b0c340b6c5d31de76fc12983d60f71c3ef811
Author: Chandni Singh <[email protected]>
AuthorDate: Fri Nov 24 09:57:04 2023 +0800

    [CELEBORN-1135] Added tests for the RpcEnv and related classes
    
    ### What changes were proposed in this pull request?
    Added test suites for `RpcEnv`, `NettyRpcEnv`, and other related classes.
    These are copied over from Apache Spark. Some of the UTs in Apache Spark 
required changes in the source code like 
[SPARK-39468](https://issues.apache.org/jira/browse/SPARK-39468) which I didn't 
copy over.
    
    ### Why are the changes needed?
    The change adds unit tests.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Just adds UTs. The source code changes are minimal.
    
    Closes #2107 from otterc/CELEBORN-1135.
    
    Authored-by: Chandni Singh <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../apache/celeborn/common/rpc/netty/Inbox.scala   |  28 +-
 .../celeborn/common/rpc/RpcAddressSuite.scala      |  57 ++
 .../apache/celeborn/common/rpc/RpcEnvSuite.scala   | 855 +++++++++++++++++++++
 .../celeborn/common/rpc/TestRpcEndpoint.scala      | 124 +++
 .../celeborn/common/rpc/netty/InboxSuite.scala     | 158 ++++
 .../common/rpc/netty/NettyRpcAddressSuite.scala    |  34 +
 .../common/rpc/netty/NettyRpcEnvSuite.scala        | 138 ++++
 .../common/rpc/netty/NettyRpcHandlerSuite.scala    |  66 ++
 8 files changed, 1458 insertions(+), 2 deletions(-)

diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
index b24290fe7..09cdd08e2 100644
--- a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
@@ -106,7 +106,7 @@ private[celeborn] class Inbox(
       }
     }
     while (true) {
-      safelyCall(endpoint) {
+      safelyCall(endpoint, endpointRef.name) {
         message match {
           case RpcMessage(_sender, content, context) =>
             try {
@@ -218,7 +218,21 @@ private[celeborn] class Inbox(
   /**
    * Calls action closure, and calls the endpoint's onError function in the 
case of exceptions.
    */
-  private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
+  private def safelyCall(
+      endpoint: RpcEndpoint,
+      endpointRefName: String)(action: => Unit): Unit = {
+    def dealWithFatalError(fatal: Throwable): Unit = {
+      inbox.synchronized {
+        assert(numActiveThreads > 0, "The number of active threads should be 
positive.")
+        // Should reduce the number of active threads before throw the error.
+        numActiveThreads -= 1
+      }
+      logError(
+        s"An error happened while processing message in the inbox for 
$endpointRefName",
+        fatal)
+      throw fatal
+    }
+
     try action
     catch {
       case NonFatal(e) =>
@@ -230,8 +244,18 @@ private[celeborn] class Inbox(
             } else {
               logError("Ignoring error", ee)
             }
+          case fatal: Throwable =>
+            dealWithFatalError(fatal)
         }
+      case fatal: Throwable =>
+        dealWithFatalError(fatal)
     }
   }
 
+  // exposed only for testing
+  def getNumActiveThreads: Int = {
+    inbox.synchronized {
+      inbox.numActiveThreads
+    }
+  }
 }
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala 
b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala
new file mode 100644
index 000000000..9bc322085
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.exception.CelebornException
+
+class RpcAddressSuite extends CelebornFunSuite {
+
+  test("hostPort") {
+    val address = RpcAddress("1.2.3.4", 1234)
+    assert(address.host === "1.2.3.4")
+    assert(address.port === 1234)
+    assert(address.hostPort === "1.2.3.4:1234")
+  }
+
+  test("fromCelebornURL") {
+    val address = RpcAddress.fromCelebornURL("celeborn://1.2.3.4:1234")
+    assert(address.host === "1.2.3.4")
+    assert(address.port === 1234)
+  }
+
+  test("fromCelebornURL: a typo url") {
+    val e = intercept[CelebornException] {
+      RpcAddress.fromCelebornURL("celeborn://1.2. 3.4:1234")
+    }
+    assert("Invalid master URL: celeborn://1.2. 3.4:1234" === e.getMessage)
+  }
+
+  test("fromCelebornURL: invalid scheme") {
+    val e = intercept[CelebornException] {
+      RpcAddress.fromCelebornURL("invalid://1.2.3.4:1234")
+    }
+    assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage)
+  }
+
+  test("toCelebornURL") {
+    val address = RpcAddress("1.2.3.4", 1234)
+    assert(address.toCelebornURL === "celeborn://1.2.3.4:1234")
+  }
+
+}
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala 
b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala
new file mode 100644
index 000000000..4843cf744
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala
@@ -0,0 +1,855 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc
+
+import java.io.NotSerializableException
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}
+
+import scala.collection.JavaConverters.collectionAsScalaIterableConverter
+import scala.collection.mutable
+import scala.concurrent.Await
+import scala.concurrent.duration._
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, never, verify}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.util.ThreadUtils
+
+/**
+ * Common tests for an RpcEnv implementation.
+ */
+abstract class RpcEnvSuite extends CelebornFunSuite {
+
+  var env: RpcEnv = _
+
+  def createCelebornConf(): CelebornConf = {
+    new CelebornConf()
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val conf = createCelebornConf()
+    env = createRpcEnv(conf, "local", 0)
+
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      if (env != null) {
+        env.shutdown()
+      }
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  def createRpcEnv(conf: CelebornConf, name: String, port: Int, clientMode: 
Boolean = false): RpcEnv
+
+  test("send a message locally") {
+    @volatile var message: String = null
+    val rpcEndpointRef = env.setupEndpoint(
+      "send-locally",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receive = {
+          case msg: String => message = msg
+        }
+      })
+    rpcEndpointRef.send("hello")
+    eventually(timeout(5.seconds), interval(10.milliseconds)) {
+      assert("hello" === message)
+    }
+  }
+
+  test("send a message remotely") {
+    @volatile var message: String = null
+    // Set up a RpcEndpoint using env
+    env.setupEndpoint(
+      "send-remotely",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case msg: String => message = msg
+        }
+      })
+
+    val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, 
clientMode = true)
+    // Use anotherEnv to find out the RpcEndpointRef
+    val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, 
"send-remotely")
+    try {
+      rpcEndpointRef.send("hello")
+      eventually(timeout(5.seconds), interval(10.milliseconds)) {
+        assert("hello" === message)
+      }
+    } finally {
+      anotherEnv.shutdown()
+      anotherEnv.awaitTermination()
+    }
+  }
+
+  test("send a RpcEndpointRef") {
+    val endpoint = new RpcEndpoint {
+      override val rpcEnv = env
+
+      override def receiveAndReply(context: RpcCallContext) = {
+        case "Hello" => context.reply(self)
+        case "Echo" => context.reply("Echo")
+      }
+    }
+    val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint)
+    val newRpcEndpointRef = rpcEndpointRef.askSync[RpcEndpointRef]("Hello")
+    val reply = newRpcEndpointRef.askSync[String]("Echo")
+    assert("Echo" === reply)
+  }
+
+  test("ask a message locally") {
+    val rpcEndpointRef = env.setupEndpoint(
+      "ask-locally",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case msg: String =>
+            context.reply(msg)
+        }
+      })
+    val reply = rpcEndpointRef.askSync[String]("hello")
+    assert("hello" === reply)
+  }
+
+  test("ask a message remotely") {
+    env.setupEndpoint(
+      "ask-remotely",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case msg: String =>
+            context.reply(msg)
+        }
+      })
+
+    val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, 
clientMode = true)
+    // Use anotherEnv to find out the RpcEndpointRef
+    val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, 
"ask-remotely")
+    try {
+      val reply = rpcEndpointRef.askSync[String]("hello")
+      assert("hello" === reply)
+    } finally {
+      anotherEnv.shutdown()
+      anotherEnv.awaitTermination()
+    }
+  }
+
+  test("ask a message timeout") {
+    env.setupEndpoint(
+      "ask-timeout",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case msg: String =>
+            Thread.sleep(100)
+            context.reply(msg)
+        }
+      })
+
+    val conf = createCelebornConf()
+    val shortProp = "celeborn.rpc.short.timeout"
+    val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true)
+    // Use anotherEnv to find out the RpcEndpointRef
+    val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, 
"ask-timeout")
+    try {
+      val e = intercept[RpcTimeoutException] {
+        rpcEndpointRef.askSync[String]("hello", new RpcTimeout(1.millisecond, 
shortProp))
+      }
+      // The Celeborn exception cause should be a RpcTimeoutException with 
message indicating the
+      // controlling timeout property
+      assert(e.isInstanceOf[RpcTimeoutException])
+      assert(e.getMessage.contains(shortProp))
+    } finally {
+      anotherEnv.shutdown()
+      anotherEnv.awaitTermination()
+    }
+  }
+
+  test("onStart and onStop") {
+    val stopLatch = new CountDownLatch(1)
+    val calledMethods = mutable.ArrayBuffer[String]()
+
+    val endpoint = new RpcEndpoint {
+      override val rpcEnv = env
+
+      override def onStart(): Unit = {
+        calledMethods += "start"
+      }
+
+      override def receive: PartialFunction[Any, Unit] = {
+        case msg: String =>
+      }
+
+      override def onStop(): Unit = {
+        calledMethods += "stop"
+        stopLatch.countDown()
+      }
+    }
+    val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint)
+    env.stop(rpcEndpointRef)
+    stopLatch.await(10, TimeUnit.SECONDS)
+    assert(List("start", "stop") === calledMethods)
+  }
+
+  test("onError: error in onStart") {
+    @volatile var e: Throwable = null
+    env.setupEndpoint(
+      "onError-onStart",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def onStart(): Unit = {
+          throw new RuntimeException("Oops!")
+        }
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case m =>
+        }
+
+        override def onError(cause: Throwable): Unit = {
+          e = cause
+        }
+      })
+
+    eventually(timeout(5.seconds), interval(10.milliseconds)) {
+      assert(e.getMessage === "Oops!")
+    }
+  }
+
+  test("onError: error in onStop") {
+    @volatile var e: Throwable = null
+    val endpointRef = env.setupEndpoint(
+      "onError-onStop",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case m =>
+        }
+
+        override def onError(cause: Throwable): Unit = {
+          e = cause
+        }
+
+        override def onStop(): Unit = {
+          throw new RuntimeException("Oops!")
+        }
+      })
+
+    env.stop(endpointRef)
+
+    eventually(timeout(5.seconds), interval(10.milliseconds)) {
+      assert(e.getMessage === "Oops!")
+    }
+  }
+
+  test("onError: error in receive") {
+    @volatile var e: Throwable = null
+    val endpointRef = env.setupEndpoint(
+      "onError-receive",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case m => throw new RuntimeException("Oops!")
+        }
+
+        override def onError(cause: Throwable): Unit = {
+          e = cause
+        }
+      })
+
+    endpointRef.send("Foo")
+
+    eventually(timeout(5.seconds), interval(10.milliseconds)) {
+      assert(e.getMessage === "Oops!")
+    }
+  }
+
+  test("self: call in onStart") {
+    @volatile var callSelfSuccessfully = false
+
+    env.setupEndpoint(
+      "self-onStart",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def onStart(): Unit = {
+          self
+          callSelfSuccessfully = true
+        }
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case m =>
+        }
+      })
+
+    eventually(timeout(5.seconds), interval(10.milliseconds)) {
+      // Calling `self` in `onStart` is fine
+      assert(callSelfSuccessfully)
+    }
+  }
+
+  test("self: call in receive") {
+    @volatile var callSelfSuccessfully = false
+
+    val endpointRef = env.setupEndpoint(
+      "self-receive",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case m =>
+            self
+            callSelfSuccessfully = true
+        }
+      })
+
+    endpointRef.send("Foo")
+
+    eventually(timeout(5.seconds), interval(10.milliseconds)) {
+      // Calling `self` in `receive` is fine
+      assert(callSelfSuccessfully)
+    }
+  }
+
+  test("self: call in onStop") {
+    @volatile var selfOption: Option[RpcEndpointRef] = null
+
+    val endpointRef = env.setupEndpoint(
+      "self-onStop",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case m =>
+        }
+
+        override def onStop(): Unit = {
+          selfOption = Option(self)
+        }
+      })
+
+    env.stop(endpointRef)
+
+    eventually(timeout(5.seconds), interval(10.milliseconds)) {
+      // Calling `self` in `onStop` will return null, so selfOption will be 
None
+      assert(selfOption.isEmpty)
+    }
+  }
+
+  test("call receive in sequence") {
+    // If a RpcEnv implementation breaks the `receive` contract, hope this 
test can expose it
+    for (i <- 0 until 100) {
+      @volatile var result = 0
+      val endpointRef = env.setupEndpoint(
+        s"receive-in-sequence-$i",
+        new ThreadSafeRpcEndpoint {
+          override val rpcEnv = env
+
+          override def receive: PartialFunction[Any, Unit] = {
+            case m => result += 1
+          }
+
+        })
+
+      (0 until 10) foreach { _ =>
+        new Thread {
+          override def run(): Unit = {
+            (0 until 100) foreach { _ =>
+              endpointRef.send("Hello")
+            }
+          }
+        }.start()
+      }
+
+      eventually(timeout(5.seconds), interval(5.milliseconds)) {
+        assert(result == 1000)
+      }
+
+      env.stop(endpointRef)
+    }
+  }
+
+  test("stop(RpcEndpointRef) reentrant") {
+    @volatile var onStopCount = 0
+    val endpointRef = env.setupEndpoint(
+      "stop-reentrant",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case m =>
+        }
+
+        override def onStop(): Unit = {
+          onStopCount += 1
+        }
+      })
+
+    env.stop(endpointRef)
+    env.stop(endpointRef)
+
+    eventually(timeout(5.seconds), interval(5.milliseconds)) {
+      // Calling stop twice should only trigger onStop once.
+      assert(onStopCount == 1)
+    }
+  }
+
+  test("sendWithReply") {
+    val endpointRef = env.setupEndpoint(
+      "sendWithReply",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case m => context.reply("ack")
+        }
+      })
+
+    val f = endpointRef.ask[String]("Hi")
+    val ack = ThreadUtils.awaitResult(f, 5.seconds)
+    assert("ack" === ack)
+
+    env.stop(endpointRef)
+  }
+
+  test("sendWithReply: remotely") {
+    env.setupEndpoint(
+      "sendWithReply-remotely",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case m => context.reply("ack")
+        }
+      })
+
+    val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, 
clientMode = true)
+    // Use anotherEnv to find out the RpcEndpointRef
+    val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, 
"sendWithReply-remotely")
+    try {
+      val f = rpcEndpointRef.ask[String]("hello")
+      val ack = ThreadUtils.awaitResult(f, 5.seconds)
+      assert("ack" === ack)
+    } finally {
+      anotherEnv.shutdown()
+      anotherEnv.awaitTermination()
+    }
+  }
+
+  test("sendWithReply: error") {
+    val endpointRef = env.setupEndpoint(
+      "sendWithReply-error",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case m => context.sendFailure(new CelebornException("Oops"))
+        }
+      })
+
+    val f = endpointRef.ask[String]("Hi")
+    val e = intercept[CelebornException] {
+      ThreadUtils.awaitResult(f, 5.seconds)
+    }
+    assert("Oops" === e.getCause.getMessage)
+
+    env.stop(endpointRef)
+  }
+
+  test("sendWithReply: remotely error") {
+    env.setupEndpoint(
+      "sendWithReply-remotely-error",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case msg: String => context.sendFailure(new 
CelebornException("Oops"))
+        }
+      })
+
+    val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, 
clientMode = true)
+    // Use anotherEnv to find out the RpcEndpointRef
+    val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, 
"sendWithReply-remotely-error")
+    try {
+      val f = rpcEndpointRef.ask[String]("hello")
+      val e = intercept[CelebornException] {
+        ThreadUtils.awaitResult(f, 5.seconds)
+      }
+      assert("Oops" === e.getCause.getMessage)
+    } finally {
+      anotherEnv.shutdown()
+      anotherEnv.awaitTermination()
+    }
+  }
+
+  /**
+   * Setup an [[RpcEndpoint]] to collect all network events.
+   *
+   * @return the [[RpcEndpointRef]] and a `ConcurrentLinkedQueue` that 
contains network events.
+   */
+  private def setupNetworkEndpoint(
+      _env: RpcEnv,
+      name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = {
+    val events = new ConcurrentLinkedQueue[(Any, Any)]
+    val ref = _env.setupEndpoint(
+      "network-events-non-client",
+      new ThreadSafeRpcEndpoint {
+        override val rpcEnv = _env
+
+        override def receive: PartialFunction[Any, Unit] = {
+          case "hello" =>
+          case m => events.add("receive" -> m)
+        }
+
+        override def onConnected(remoteAddress: RpcAddress): Unit = {
+          events.add("onConnected" -> remoteAddress)
+        }
+
+        override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+          events.add("onDisconnected" -> remoteAddress)
+        }
+
+        override def onNetworkError(cause: Throwable, remoteAddress: 
RpcAddress): Unit = {
+          events.add("onNetworkError" -> remoteAddress)
+        }
+
+      })
+    (ref, events)
+  }
+
+  test("network events in sever RpcEnv when another RpcEnv is in server mode") 
{
+    val serverEnv1 = createRpcEnv(createCelebornConf(), "server1", 0, 
clientMode = false)
+    val serverEnv2 = createRpcEnv(createCelebornConf(), "server2", 0, 
clientMode = false)
+    val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events")
+    val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events")
+    try {
+      val serverRefInServer2 = serverEnv1.setupEndpointRef(serverRef2.address, 
serverRef2.name)
+      // Send a message to set up the connection
+      serverRefInServer2.send("hello")
+
+      eventually(timeout(5.seconds), interval(5.milliseconds)) {
+        assert(events.contains(("onConnected", serverEnv2.address)))
+      }
+
+      serverEnv2.shutdown()
+      serverEnv2.awaitTermination()
+
+      eventually(timeout(5.seconds), interval(5.milliseconds)) {
+        assert(events.contains(("onConnected", serverEnv2.address)))
+        assert(events.contains(("onDisconnected", serverEnv2.address)))
+      }
+    } finally {
+      serverEnv1.shutdown()
+      serverEnv2.shutdown()
+      serverEnv1.awaitTermination()
+      serverEnv2.awaitTermination()
+    }
+  }
+
+  test("network events in sever RpcEnv when another RpcEnv is in client mode") 
{
+    val serverEnv = createRpcEnv(createCelebornConf(), "server", 0, clientMode 
= false)
+    val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events")
+    val clientEnv = createRpcEnv(createCelebornConf(), "client", 0, clientMode 
= true)
+    try {
+      val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, 
serverRef.name)
+      // Send a message to set up the connection
+      serverRefInClient.send("hello")
+
+      eventually(timeout(5.seconds), interval(5.milliseconds)) {
+        // We don't know the exact client address but at least we can verify 
the message type
+        assert(events.asScala.map(_._1).exists(_ == "onConnected"))
+      }
+
+      clientEnv.shutdown()
+      clientEnv.awaitTermination()
+
+      eventually(timeout(5.seconds), interval(5.milliseconds)) {
+        // We don't know the exact client address but at least we can verify 
the message type
+        assert(events.asScala.map(_._1).exists(_ == "onConnected"))
+        assert(events.asScala.map(_._1).exists(_ == "onDisconnected"))
+      }
+    } finally {
+      clientEnv.shutdown()
+      serverEnv.shutdown()
+      clientEnv.awaitTermination()
+      serverEnv.awaitTermination()
+    }
+  }
+
+  test("network events in client RpcEnv when another RpcEnv is in server 
mode") {
+    val clientEnv = createRpcEnv(createCelebornConf(), "client", 0, clientMode 
= true)
+    val serverEnv = createRpcEnv(createCelebornConf(), "server", 0, clientMode 
= false)
+    val (_, events) = setupNetworkEndpoint(clientEnv, "network-events")
+    val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events")
+    try {
+      val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, 
serverRef.name)
+      // Send a message to set up the connection
+      serverRefInClient.send("hello")
+
+      eventually(timeout(5.seconds), interval(5.milliseconds)) {
+        assert(events.contains(("onConnected", serverEnv.address)))
+      }
+
+      serverEnv.shutdown()
+      serverEnv.awaitTermination()
+
+      eventually(timeout(5.seconds), interval(5.milliseconds)) {
+        assert(events.contains(("onConnected", serverEnv.address)))
+        assert(events.contains(("onDisconnected", serverEnv.address)))
+      }
+    } finally {
+      clientEnv.shutdown()
+      serverEnv.shutdown()
+      clientEnv.awaitTermination()
+      serverEnv.awaitTermination()
+    }
+  }
+
+  test("sendWithReply: unserializable error") {
+    env.setupEndpoint(
+      "sendWithReply-unserializable-error",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case msg: String => context.sendFailure(new UnserializableException)
+        }
+      })
+
+    val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, 
clientMode = true)
+    // Use anotherEnv to find out the RpcEndpointRef
+    val rpcEndpointRef =
+      anotherEnv.setupEndpointRef(env.address, 
"sendWithReply-unserializable-error")
+    try {
+      val f = rpcEndpointRef.ask[String]("hello")
+      val e = intercept[CelebornException] {
+        ThreadUtils.awaitResult(f, 1.second)
+      }
+      assert(e.getCause.isInstanceOf[NotSerializableException])
+    } finally {
+      anotherEnv.shutdown()
+      anotherEnv.awaitTermination()
+    }
+  }
+
+  test("port conflict") {
+    val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 
env.address.port)
+    try {
+      assert(anotherEnv.address.port != env.address.port)
+    } finally {
+      anotherEnv.shutdown()
+      anotherEnv.awaitTermination()
+    }
+  }
+
+  private def testSend(conf: CelebornConf): Unit = {
+    val localEnv = createRpcEnv(conf, "authentication-local", 0)
+    val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode 
= true)
+
+    try {
+      @volatile var message: String = null
+      localEnv.setupEndpoint(
+        "send-authentication",
+        new RpcEndpoint {
+          override val rpcEnv = localEnv
+
+          override def receive: PartialFunction[Any, Unit] = {
+            case msg: String => message = msg
+          }
+        })
+      val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, 
"send-authentication")
+      rpcEndpointRef.send("hello")
+      eventually(timeout(5.seconds), interval(10.milliseconds)) {
+        assert("hello" === message)
+      }
+    } finally {
+      localEnv.shutdown()
+      localEnv.awaitTermination()
+      remoteEnv.shutdown()
+      remoteEnv.awaitTermination()
+    }
+  }
+
+  private def testAsk(conf: CelebornConf): Unit = {
+    val localEnv = createRpcEnv(conf, "authentication-local", 0)
+    val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode 
= true)
+
+    try {
+      localEnv.setupEndpoint(
+        "ask-authentication",
+        new RpcEndpoint {
+          override val rpcEnv = localEnv
+
+          override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+            case msg: String =>
+              context.reply(msg)
+          }
+        })
+      val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, 
"ask-authentication")
+      val reply = rpcEndpointRef.askSync[String]("hello")
+      assert("hello" === reply)
+    } finally {
+      localEnv.shutdown()
+      localEnv.awaitTermination()
+      remoteEnv.shutdown()
+      remoteEnv.awaitTermination()
+    }
+  }
+
+  test("construct RpcTimeout with conf property") {
+    val conf = new CelebornConf()
+
+    val testProp = "celeborn.ask.test.timeout"
+    val testDurationSeconds = 30
+    val secondaryProp = "celeborn.ask.secondary.timeout"
+
+    conf.set(testProp, s"${testDurationSeconds}s")
+    conf.set(secondaryProp, "100s")
+
+    // Construct RpcTimeout with a single property
+    val rt1 = RpcTimeout(conf, testProp)
+    assert(testDurationSeconds === rt1.duration.toSeconds)
+
+    // Construct RpcTimeout with prioritized list of properties
+    val rt2 = RpcTimeout(conf, Seq("celeborn.ask.invalid.timeout", testProp, 
secondaryProp), "1s")
+    assert(testDurationSeconds === rt2.duration.toSeconds)
+
+    // Construct RpcTimeout with default value,
+    val defaultProp = "celeborn.ask.default.timeout"
+    val defaultDurationSeconds = 1
+    val rt3 = RpcTimeout(conf, Seq(defaultProp), 
defaultDurationSeconds.toString + "s")
+    assert(defaultDurationSeconds === rt3.duration.toSeconds)
+    assert(rt3.timeoutProp.contains(defaultProp))
+
+    // Try to construct RpcTimeout with an unconfigured property
+    intercept[NoSuchElementException] {
+      RpcTimeout(conf, "celeborn.ask.invalid.timeout")
+    }
+  }
+
+  test("ask a message timeout on Future using RpcTimeout") {
+    case class NeverReply(msg: String)
+
+    val rpcEndpointRef = env.setupEndpoint(
+      "ask-future",
+      new RpcEndpoint {
+        override val rpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case msg: String => context.reply(msg)
+          case _: NeverReply =>
+        }
+      })
+
+    val longTimeout = new RpcTimeout(1.second, "celeborn.rpc.long.timeout")
+    val shortTimeout = new RpcTimeout(10.milliseconds, 
"celeborn.rpc.short.timeout")
+
+    // Ask with immediate response, should complete successfully
+    val fut1 = rpcEndpointRef.ask[String]("hello", longTimeout)
+    val reply1 = longTimeout.awaitResult(fut1)
+    assert("hello" === reply1)
+
+    // Ask with a delayed response and wait for response immediately that 
should timeout
+    val fut2 = rpcEndpointRef.ask[String](NeverReply("doh"), shortTimeout)
+    val reply2 =
+      intercept[RpcTimeoutException] {
+        shortTimeout.awaitResult(fut2)
+      }.getMessage
+
+    // RpcTimeout.awaitResult should have added the property to the 
TimeoutException message
+    assert(reply2.contains(shortTimeout.timeoutProp))
+
+    // Ask with delayed response and allow the Future to timeout before 
ThreadUtils.awaitResult
+    val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout)
+
+    // scalastyle:off awaitresult
+    // Allow future to complete with failure using plain Await.result, this 
will return
+    // once the future is complete to verify addMessageIfTimeout was invoked
+    val reply3 =
+      intercept[RpcTimeoutException] {
+        Await.result(fut3, 2.seconds)
+      }.getMessage
+    // scalastyle:on awaitresult
+
+    // When the future timed out, the recover callback should have used
+    // RpcTimeout.addMessageIfTimeout to add the property to the 
TimeoutException message
+    assert(reply3.contains(shortTimeout.timeoutProp))
+
+    // Use RpcTimeout.awaitResult to process Future, since it has already 
failed with
+    // RpcTimeoutException, the same RpcTimeoutException should be thrown
+    val reply4 =
+      intercept[RpcTimeoutException] {
+        shortTimeout.awaitResult(fut3)
+      }.getMessage
+
+    // Ensure description is not in message twice after addMessageIfTimeout 
and awaitResult
+    assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
+  }
+
+  test("RpcEnv.shutdown should not fire onDisconnected events") {
+    env.setupEndpoint(
+      "test_ep_11212023",
+      new RpcEndpoint {
+        override val rpcEnv: RpcEnv = env
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          case m => context.reply(m)
+        }
+      })
+
+    val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0)
+    val endpoint = mock(classOf[RpcEndpoint])
+    anotherEnv.setupEndpoint("test_ep_11212023", endpoint)
+
+    val ref = anotherEnv.setupEndpointRef(env.address, "test_ep_11212023")
+    // Make sure the connect is set up
+    assert(ref.askSync[String]("hello") === "hello")
+    anotherEnv.shutdown()
+    anotherEnv.awaitTermination()
+
+    env.stop(ref)
+
+    verify(endpoint).onStop()
+    verify(endpoint, never()).onDisconnected(any())
+    verify(endpoint, never()).onNetworkError(any(), any())
+  }
+}
+
+case class Register(ref: RpcEndpointRef)
+
+class UnserializableClass
+
+class UnserializableException extends Exception {
+  private val unserializableField = new UnserializableClass
+}
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala 
b/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala
new file mode 100644
index 000000000..9d343b037
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalactic.TripleEquals
+import org.scalatest.Assertions._
+
+class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals {
+
+  override val rpcEnv: RpcEnv = null
+
+  @volatile private var receiveMessages = ArrayBuffer[Any]()
+
+  @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]()
+
+  @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]()
+
+  @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]()
+
+  @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable, 
RpcAddress)]()
+
+  @volatile private var started = false
+
+  @volatile private var stopped = false
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case message: Any => receiveMessages += message
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
+    case message: Any => receiveAndReplyMessages += message
+  }
+
+  override def onConnected(remoteAddress: RpcAddress): Unit = {
+    onConnectedMessages += remoteAddress
+  }
+
+  /**
+   * Invoked when some network error happens in the connection between the 
current node and
+   * `remoteAddress`.
+   */
+  override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): 
Unit = {
+    onNetworkErrorMessages += cause -> remoteAddress
+  }
+
+  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+    onDisconnectedMessages += remoteAddress
+  }
+
+  def numReceiveMessages: Int = receiveMessages.size
+
+  override def onStart(): Unit = {
+    started = true
+  }
+
+  override def onStop(): Unit = {
+    stopped = true
+  }
+
+  def verifyStarted(): Unit = {
+    assert(started, "RpcEndpoint is not started")
+  }
+
+  def verifyStopped(): Unit = {
+    assert(stopped, "RpcEndpoint is not stopped")
+  }
+
+  def verifyReceiveMessages(expected: Seq[Any]): Unit = {
+    assert(receiveMessages === expected)
+  }
+
+  def verifySingleReceiveMessage(message: Any): Unit = {
+    verifyReceiveMessages(List(message))
+  }
+
+  def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = {
+    assert(receiveAndReplyMessages === expected)
+  }
+
+  def verifySingleReceiveAndReplyMessage(message: Any): Unit = {
+    verifyReceiveAndReplyMessages(List(message))
+  }
+
+  def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = {
+    verifyOnConnectedMessages(List(remoteAddress))
+  }
+
+  def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = {
+    assert(onConnectedMessages === expected)
+  }
+
+  def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = {
+    verifyOnDisconnectedMessages(List(remoteAddress))
+  }
+
+  def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = {
+    assert(onDisconnectedMessages === expected)
+  }
+
+  def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress: 
RpcAddress): Unit = {
+    verifyOnNetworkErrorMessages(List(cause -> remoteAddress))
+  }
+
+  def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]): 
Unit = {
+    assert(onNetworkErrorMessages === expected)
+  }
+}
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
new file mode 100644
index 000000000..a8bc826dd
--- /dev/null
+++ 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.mockito.Mockito._
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.rpc.{RpcAddress, TestRpcEndpoint}
+
+class InboxSuite extends CelebornFunSuite {
+
+  test("post") {
+    val endpoint = new TestRpcEndpoint
+    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val inbox = new Inbox(rpcEnvRef, endpoint)
+    val message = OneWayMessage(null, "hi")
+    inbox.post(message)
+    inbox.process(dispatcher)
+    assert(inbox.isEmpty)
+
+    endpoint.verifySingleReceiveMessage("hi")
+
+    inbox.stop()
+    inbox.process(dispatcher)
+    assert(inbox.isEmpty)
+    endpoint.verifyStarted()
+    endpoint.verifyStopped()
+  }
+
+  test("post: with reply") {
+    val endpoint = new TestRpcEndpoint
+    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val inbox = new Inbox(rpcEnvRef, endpoint)
+    val message = RpcMessage(null, "hi", null)
+    inbox.post(message)
+    inbox.process(dispatcher)
+    assert(inbox.isEmpty)
+
+    endpoint.verifySingleReceiveAndReplyMessage("hi")
+  }
+
+  test("post: multiple threads") {
+    val endpoint = new TestRpcEndpoint
+    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val numDroppedMessages = new AtomicInteger(0)
+    val inbox = new Inbox(rpcEnvRef, endpoint) {
+      override def onDrop(message: InboxMessage): Unit = {
+        numDroppedMessages.incrementAndGet()
+      }
+    }
+
+    val exitLatch = new CountDownLatch(10)
+
+    for (_ <- 0 until 10) {
+      new Thread {
+        override def run(): Unit = {
+          for (_ <- 0 until 100) {
+            val message = OneWayMessage(null, "hi")
+            inbox.post(message)
+          }
+          exitLatch.countDown()
+        }
+      }.start()
+    }
+    // Try to process some messages
+    inbox.process(dispatcher)
+    inbox.stop()
+    // After `stop` is called, further messages will be dropped. However, 
while `stop` is called,
+    // some messages may be post to Inbox, so process them here.
+    inbox.process(dispatcher)
+    assert(inbox.isEmpty)
+
+    exitLatch.await(30, TimeUnit.SECONDS)
+
+    assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get)
+    endpoint.verifyStarted()
+    endpoint.verifyStopped()
+  }
+
+  test("post: Associated") {
+    val endpoint = new TestRpcEndpoint
+    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+    val remoteAddress = RpcAddress("localhost", 11111)
+
+    val inbox = new Inbox(rpcEnvRef, endpoint)
+    inbox.post(RemoteProcessConnected(remoteAddress))
+    inbox.process(dispatcher)
+
+    endpoint.verifySingleOnConnectedMessage(remoteAddress)
+  }
+
+  test("post: Disassociated") {
+    val endpoint = new TestRpcEndpoint
+    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val remoteAddress = RpcAddress("localhost", 11111)
+
+    val inbox = new Inbox(rpcEnvRef, endpoint)
+    inbox.post(RemoteProcessDisconnected(remoteAddress))
+    inbox.process(dispatcher)
+
+    endpoint.verifySingleOnDisconnectedMessage(remoteAddress)
+  }
+
+  test("post: AssociationError") {
+    val endpoint = new TestRpcEndpoint
+    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val remoteAddress = RpcAddress("localhost", 11111)
+    val cause = new RuntimeException("Oops")
+
+    val inbox = new Inbox(rpcEnvRef, endpoint)
+    inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
+    inbox.process(dispatcher)
+
+    endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress)
+  }
+
+  test("should reduce the number of active threads when fatal error happens") {
+    val endpoint = mock(classOf[TestRpcEndpoint])
+    when(endpoint.receive).thenThrow(new OutOfMemoryError())
+    val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+    val inbox = new Inbox(rpcEnvRef, endpoint)
+    inbox.post(OneWayMessage(null, "hi"))
+    intercept[OutOfMemoryError] {
+      inbox.process(dispatcher)
+    }
+    assert(inbox.getNumActiveThreads === 0)
+  }
+}
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala
 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala
new file mode 100644
index 000000000..1c3fe6ef4
--- /dev/null
+++ 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.rpc.RpcEndpointAddress
+
+class NettyRpcAddressSuite extends CelebornFunSuite {
+
+  test("toString") {
+    val addr = new RpcEndpointAddress("localhost", 12345, "test")
+    assert(addr.toString === "celeborn://test@localhost:12345")
+  }
+
+  test("toString for client mode") {
+    val addr = RpcEndpointAddress(null, "test")
+    assert(addr.toString === "celeborn-client://test")
+  }
+}
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala
 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala
new file mode 100644
index 000000000..8afbf5980
--- /dev/null
+++ 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import java.util.concurrent.ExecutionException
+
+import scala.concurrent.duration._
+
+import org.mockito.Mockito.mock
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
+
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.network.client.TransportClient
+import org.apache.celeborn.common.rpc._
+import org.apache.celeborn.common.util.ThreadUtils
+
+class NettyRpcEnvSuite extends RpcEnvSuite with TimeLimits {
+
+  implicit private val signaler: Signaler = ThreadSignaler
+
+  override def createRpcEnv(
+      conf: CelebornConf,
+      name: String,
+      port: Int,
+      clientMode: Boolean = false): RpcEnv = {
+    val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port, 0)
+    new NettyRpcEnvFactory().create(config)
+  }
+
+  test("non-existent endpoint") {
+    val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString
+    val e = intercept[CelebornException] {
+      env.setupEndpointRef(env.address, "nonexist-endpoint")
+    }
+    assert(e.getCause.isInstanceOf[RpcEndpointNotFoundException])
+    assert(e.getCause.getMessage.contains(uri))
+  }
+
+  test("advertise address different from bind address") {
+    val celebornConf = createCelebornConf()
+    val config = RpcEnvConfig(celebornConf, "test", "localhost", 
"example.com", 0, 0)
+    val env = new NettyRpcEnvFactory().create(config)
+    try {
+      assert(env.address.hostPort.startsWith("example.com:"))
+    } finally {
+      env.shutdown()
+    }
+  }
+
+  test("RequestMessage serialization") {
+    def assertRequestMessageEquals(expected: RequestMessage, actual: 
RequestMessage): Unit = {
+      assert(expected.senderAddress === actual.senderAddress)
+      assert(expected.receiver === actual.receiver)
+      assert(expected.content === actual.content)
+    }
+
+    val nettyEnv = env.asInstanceOf[NettyRpcEnv]
+    val client = mock(classOf[TransportClient])
+    val senderAddress = RpcAddress("localhost", 12345)
+    val receiverAddress = RpcEndpointAddress("localhost", 54321, "test")
+    val receiver = new NettyRpcEndpointRef(nettyEnv.celebornConf, 
receiverAddress, nettyEnv)
+
+    val msg = new RequestMessage(senderAddress, receiver, "foo")
+    assertRequestMessageEquals(
+      msg,
+      RequestMessage(nettyEnv, client, msg.serialize(nettyEnv)))
+
+    val msg2 = new RequestMessage(null, receiver, "foo")
+    assertRequestMessageEquals(
+      msg2,
+      RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv)))
+
+    val msg3 = new RequestMessage(senderAddress, receiver, null)
+    assertRequestMessageEquals(
+      msg3,
+      RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv)))
+  }
+
+  test("StackOverflowError should be sent back and Dispatcher should survive") 
{
+    val numUsableCores = 2
+    val conf = createCelebornConf()
+    val config = RpcEnvConfig(
+      conf,
+      "test",
+      "localhost",
+      "localhost",
+      0,
+      numUsableCores)
+    val anotherEnv = new NettyRpcEnvFactory().create(config)
+    anotherEnv.setupEndpoint(
+      "StackOverflowError",
+      new RpcEndpoint {
+        override val rpcEnv = anotherEnv
+
+        override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
+          // scalastyle:off throwerror
+          case msg: String => throw new StackOverflowError
+          // scalastyle:on throwerror
+          case num: Int => context.reply(num)
+        }
+      })
+
+    val rpcEndpointRef = env.setupEndpointRef(anotherEnv.address, 
"StackOverflowError")
+    try {
+      // Send `numUsableCores` messages to trigger `numUsableCores` 
`StackOverflowError`s
+      for (_ <- 0 until numUsableCores) {
+        val e = intercept[CelebornException] {
+          rpcEndpointRef.askSync[String]("hello")
+        }
+        // The root cause `e.getCause.getCause` because it is boxed by Scala 
Promise.
+        assert(e.getCause.isInstanceOf[ExecutionException])
+        assert(e.getCause.getCause.isInstanceOf[StackOverflowError])
+      }
+      failAfter(10.seconds) {
+        assert(rpcEndpointRef.askSync[Int](100) === 100)
+      }
+    } finally {
+      anotherEnv.shutdown()
+      anotherEnv.awaitTermination()
+    }
+  }
+}
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala
 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala
new file mode 100644
index 000000000..3f6ebb0e6
--- /dev/null
+++ 
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+import io.netty.channel.Channel
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.network.client.{TransportClient, 
TransportResponseHandler}
+import org.apache.celeborn.common.rpc.RpcAddress
+
+class NettyRpcHandlerSuite extends CelebornFunSuite {
+
+  val env = mock(classOf[NettyRpcEnv])
+  when(env.deserialize(any(classOf[TransportClient]), 
any(classOf[ByteBuffer]))(any()))
+    .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null))
+
+  test("receive") {
+    val dispatcher = mock(classOf[Dispatcher])
+    val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+
+    val channel = mock(classOf[Channel])
+    val client = new TransportClient(channel, 
mock(classOf[TransportResponseHandler]))
+    when(channel.remoteAddress()).thenReturn(new 
InetSocketAddress("localhost", 40000))
+    nettyRpcHandler.channelActive(client)
+
+    verify(dispatcher, 
times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
+  }
+
+  test("connectionTerminated") {
+    val dispatcher = mock(classOf[Dispatcher])
+    val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+
+    val channel = mock(classOf[Channel])
+    val client = new TransportClient(channel, 
mock(classOf[TransportResponseHandler]))
+    when(channel.remoteAddress()).thenReturn(new 
InetSocketAddress("localhost", 40000))
+    nettyRpcHandler.channelActive(client)
+
+    when(channel.remoteAddress()).thenReturn(new 
InetSocketAddress("localhost", 40000))
+    nettyRpcHandler.channelInactive(client)
+
+    verify(dispatcher, 
times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
+    verify(dispatcher, times(1)).postToAll(
+      RemoteProcessDisconnected(RpcAddress("localhost", 40000)))
+  }
+
+}

Reply via email to