This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 788b0c340 [CELEBORN-1135] Added tests for the RpcEnv and related
classes
788b0c340 is described below
commit 788b0c340b6c5d31de76fc12983d60f71c3ef811
Author: Chandni Singh <[email protected]>
AuthorDate: Fri Nov 24 09:57:04 2023 +0800
[CELEBORN-1135] Added tests for the RpcEnv and related classes
### What changes were proposed in this pull request?
Added test suites for `RpcEnv`, `NettyRpcEnv`, and other related classes.
These are copied over from Apache Spark. Some of the UTs in Apache Spark
required changes in the source code like
[SPARK-39468](https://issues.apache.org/jira/browse/SPARK-39468) which I didn't
copy over.
### Why are the changes needed?
The change adds unit tests.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Just adds UTs. The source code changes are minimal.
Closes #2107 from otterc/CELEBORN-1135.
Authored-by: Chandni Singh <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../apache/celeborn/common/rpc/netty/Inbox.scala | 28 +-
.../celeborn/common/rpc/RpcAddressSuite.scala | 57 ++
.../apache/celeborn/common/rpc/RpcEnvSuite.scala | 855 +++++++++++++++++++++
.../celeborn/common/rpc/TestRpcEndpoint.scala | 124 +++
.../celeborn/common/rpc/netty/InboxSuite.scala | 158 ++++
.../common/rpc/netty/NettyRpcAddressSuite.scala | 34 +
.../common/rpc/netty/NettyRpcEnvSuite.scala | 138 ++++
.../common/rpc/netty/NettyRpcHandlerSuite.scala | 66 ++
8 files changed, 1458 insertions(+), 2 deletions(-)
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
index b24290fe7..09cdd08e2 100644
--- a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
@@ -106,7 +106,7 @@ private[celeborn] class Inbox(
}
}
while (true) {
- safelyCall(endpoint) {
+ safelyCall(endpoint, endpointRef.name) {
message match {
case RpcMessage(_sender, content, context) =>
try {
@@ -218,7 +218,21 @@ private[celeborn] class Inbox(
/**
* Calls action closure, and calls the endpoint's onError function in the
case of exceptions.
*/
- private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
+ private def safelyCall(
+ endpoint: RpcEndpoint,
+ endpointRefName: String)(action: => Unit): Unit = {
+ def dealWithFatalError(fatal: Throwable): Unit = {
+ inbox.synchronized {
+ assert(numActiveThreads > 0, "The number of active threads should be
positive.")
+ // Should reduce the number of active threads before throw the error.
+ numActiveThreads -= 1
+ }
+ logError(
+ s"An error happened while processing message in the inbox for
$endpointRefName",
+ fatal)
+ throw fatal
+ }
+
try action
catch {
case NonFatal(e) =>
@@ -230,8 +244,18 @@ private[celeborn] class Inbox(
} else {
logError("Ignoring error", ee)
}
+ case fatal: Throwable =>
+ dealWithFatalError(fatal)
}
+ case fatal: Throwable =>
+ dealWithFatalError(fatal)
}
}
+ // exposed only for testing
+ def getNumActiveThreads: Int = {
+ inbox.synchronized {
+ inbox.numActiveThreads
+ }
+ }
}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala
new file mode 100644
index 000000000..9bc322085
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.exception.CelebornException
+
+class RpcAddressSuite extends CelebornFunSuite {
+
+ test("hostPort") {
+ val address = RpcAddress("1.2.3.4", 1234)
+ assert(address.host === "1.2.3.4")
+ assert(address.port === 1234)
+ assert(address.hostPort === "1.2.3.4:1234")
+ }
+
+ test("fromCelebornURL") {
+ val address = RpcAddress.fromCelebornURL("celeborn://1.2.3.4:1234")
+ assert(address.host === "1.2.3.4")
+ assert(address.port === 1234)
+ }
+
+ test("fromCelebornURL: a typo url") {
+ val e = intercept[CelebornException] {
+ RpcAddress.fromCelebornURL("celeborn://1.2. 3.4:1234")
+ }
+ assert("Invalid master URL: celeborn://1.2. 3.4:1234" === e.getMessage)
+ }
+
+ test("fromCelebornURL: invalid scheme") {
+ val e = intercept[CelebornException] {
+ RpcAddress.fromCelebornURL("invalid://1.2.3.4:1234")
+ }
+ assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage)
+ }
+
+ test("toCelebornURL") {
+ val address = RpcAddress("1.2.3.4", 1234)
+ assert(address.toCelebornURL === "celeborn://1.2.3.4:1234")
+ }
+
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala
new file mode 100644
index 000000000..4843cf744
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala
@@ -0,0 +1,855 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc
+
+import java.io.NotSerializableException
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}
+
+import scala.collection.JavaConverters.collectionAsScalaIterableConverter
+import scala.collection.mutable
+import scala.concurrent.Await
+import scala.concurrent.duration._
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, never, verify}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.util.ThreadUtils
+
+/**
+ * Common tests for an RpcEnv implementation.
+ */
+abstract class RpcEnvSuite extends CelebornFunSuite {
+
+ var env: RpcEnv = _
+
+ def createCelebornConf(): CelebornConf = {
+ new CelebornConf()
+ }
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val conf = createCelebornConf()
+ env = createRpcEnv(conf, "local", 0)
+
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ if (env != null) {
+ env.shutdown()
+ }
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ def createRpcEnv(conf: CelebornConf, name: String, port: Int, clientMode:
Boolean = false): RpcEnv
+
+ test("send a message locally") {
+ @volatile var message: String = null
+ val rpcEndpointRef = env.setupEndpoint(
+ "send-locally",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case msg: String => message = msg
+ }
+ })
+ rpcEndpointRef.send("hello")
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ assert("hello" === message)
+ }
+ }
+
+ test("send a message remotely") {
+ @volatile var message: String = null
+ // Set up a RpcEndpoint using env
+ env.setupEndpoint(
+ "send-remotely",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case msg: String => message = msg
+ }
+ })
+
+ val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0,
clientMode = true)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address,
"send-remotely")
+ try {
+ rpcEndpointRef.send("hello")
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ assert("hello" === message)
+ }
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("send a RpcEndpointRef") {
+ val endpoint = new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case "Hello" => context.reply(self)
+ case "Echo" => context.reply("Echo")
+ }
+ }
+ val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint)
+ val newRpcEndpointRef = rpcEndpointRef.askSync[RpcEndpointRef]("Hello")
+ val reply = newRpcEndpointRef.askSync[String]("Echo")
+ assert("Echo" === reply)
+ }
+
+ test("ask a message locally") {
+ val rpcEndpointRef = env.setupEndpoint(
+ "ask-locally",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case msg: String =>
+ context.reply(msg)
+ }
+ })
+ val reply = rpcEndpointRef.askSync[String]("hello")
+ assert("hello" === reply)
+ }
+
+ test("ask a message remotely") {
+ env.setupEndpoint(
+ "ask-remotely",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case msg: String =>
+ context.reply(msg)
+ }
+ })
+
+ val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0,
clientMode = true)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address,
"ask-remotely")
+ try {
+ val reply = rpcEndpointRef.askSync[String]("hello")
+ assert("hello" === reply)
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("ask a message timeout") {
+ env.setupEndpoint(
+ "ask-timeout",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case msg: String =>
+ Thread.sleep(100)
+ context.reply(msg)
+ }
+ })
+
+ val conf = createCelebornConf()
+ val shortProp = "celeborn.rpc.short.timeout"
+ val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address,
"ask-timeout")
+ try {
+ val e = intercept[RpcTimeoutException] {
+ rpcEndpointRef.askSync[String]("hello", new RpcTimeout(1.millisecond,
shortProp))
+ }
+ // The Celeborn exception cause should be a RpcTimeoutException with
message indicating the
+ // controlling timeout property
+ assert(e.isInstanceOf[RpcTimeoutException])
+ assert(e.getMessage.contains(shortProp))
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("onStart and onStop") {
+ val stopLatch = new CountDownLatch(1)
+ val calledMethods = mutable.ArrayBuffer[String]()
+
+ val endpoint = new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def onStart(): Unit = {
+ calledMethods += "start"
+ }
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case msg: String =>
+ }
+
+ override def onStop(): Unit = {
+ calledMethods += "stop"
+ stopLatch.countDown()
+ }
+ }
+ val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint)
+ env.stop(rpcEndpointRef)
+ stopLatch.await(10, TimeUnit.SECONDS)
+ assert(List("start", "stop") === calledMethods)
+ }
+
+ test("onError: error in onStart") {
+ @volatile var e: Throwable = null
+ env.setupEndpoint(
+ "onError-onStart",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def onStart(): Unit = {
+ throw new RuntimeException("Oops!")
+ }
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case m =>
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ e = cause
+ }
+ })
+
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ assert(e.getMessage === "Oops!")
+ }
+ }
+
+ test("onError: error in onStop") {
+ @volatile var e: Throwable = null
+ val endpointRef = env.setupEndpoint(
+ "onError-onStop",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case m =>
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ e = cause
+ }
+
+ override def onStop(): Unit = {
+ throw new RuntimeException("Oops!")
+ }
+ })
+
+ env.stop(endpointRef)
+
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ assert(e.getMessage === "Oops!")
+ }
+ }
+
+ test("onError: error in receive") {
+ @volatile var e: Throwable = null
+ val endpointRef = env.setupEndpoint(
+ "onError-receive",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case m => throw new RuntimeException("Oops!")
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ e = cause
+ }
+ })
+
+ endpointRef.send("Foo")
+
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ assert(e.getMessage === "Oops!")
+ }
+ }
+
+ test("self: call in onStart") {
+ @volatile var callSelfSuccessfully = false
+
+ env.setupEndpoint(
+ "self-onStart",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def onStart(): Unit = {
+ self
+ callSelfSuccessfully = true
+ }
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case m =>
+ }
+ })
+
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ // Calling `self` in `onStart` is fine
+ assert(callSelfSuccessfully)
+ }
+ }
+
+ test("self: call in receive") {
+ @volatile var callSelfSuccessfully = false
+
+ val endpointRef = env.setupEndpoint(
+ "self-receive",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case m =>
+ self
+ callSelfSuccessfully = true
+ }
+ })
+
+ endpointRef.send("Foo")
+
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ // Calling `self` in `receive` is fine
+ assert(callSelfSuccessfully)
+ }
+ }
+
+ test("self: call in onStop") {
+ @volatile var selfOption: Option[RpcEndpointRef] = null
+
+ val endpointRef = env.setupEndpoint(
+ "self-onStop",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case m =>
+ }
+
+ override def onStop(): Unit = {
+ selfOption = Option(self)
+ }
+ })
+
+ env.stop(endpointRef)
+
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ // Calling `self` in `onStop` will return null, so selfOption will be
None
+ assert(selfOption.isEmpty)
+ }
+ }
+
+ test("call receive in sequence") {
+ // If a RpcEnv implementation breaks the `receive` contract, hope this
test can expose it
+ for (i <- 0 until 100) {
+ @volatile var result = 0
+ val endpointRef = env.setupEndpoint(
+ s"receive-in-sequence-$i",
+ new ThreadSafeRpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case m => result += 1
+ }
+
+ })
+
+ (0 until 10) foreach { _ =>
+ new Thread {
+ override def run(): Unit = {
+ (0 until 100) foreach { _ =>
+ endpointRef.send("Hello")
+ }
+ }
+ }.start()
+ }
+
+ eventually(timeout(5.seconds), interval(5.milliseconds)) {
+ assert(result == 1000)
+ }
+
+ env.stop(endpointRef)
+ }
+ }
+
+ test("stop(RpcEndpointRef) reentrant") {
+ @volatile var onStopCount = 0
+ val endpointRef = env.setupEndpoint(
+ "stop-reentrant",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case m =>
+ }
+
+ override def onStop(): Unit = {
+ onStopCount += 1
+ }
+ })
+
+ env.stop(endpointRef)
+ env.stop(endpointRef)
+
+ eventually(timeout(5.seconds), interval(5.milliseconds)) {
+ // Calling stop twice should only trigger onStop once.
+ assert(onStopCount == 1)
+ }
+ }
+
+ test("sendWithReply") {
+ val endpointRef = env.setupEndpoint(
+ "sendWithReply",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case m => context.reply("ack")
+ }
+ })
+
+ val f = endpointRef.ask[String]("Hi")
+ val ack = ThreadUtils.awaitResult(f, 5.seconds)
+ assert("ack" === ack)
+
+ env.stop(endpointRef)
+ }
+
+ test("sendWithReply: remotely") {
+ env.setupEndpoint(
+ "sendWithReply-remotely",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case m => context.reply("ack")
+ }
+ })
+
+ val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0,
clientMode = true)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address,
"sendWithReply-remotely")
+ try {
+ val f = rpcEndpointRef.ask[String]("hello")
+ val ack = ThreadUtils.awaitResult(f, 5.seconds)
+ assert("ack" === ack)
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("sendWithReply: error") {
+ val endpointRef = env.setupEndpoint(
+ "sendWithReply-error",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case m => context.sendFailure(new CelebornException("Oops"))
+ }
+ })
+
+ val f = endpointRef.ask[String]("Hi")
+ val e = intercept[CelebornException] {
+ ThreadUtils.awaitResult(f, 5.seconds)
+ }
+ assert("Oops" === e.getCause.getMessage)
+
+ env.stop(endpointRef)
+ }
+
+ test("sendWithReply: remotely error") {
+ env.setupEndpoint(
+ "sendWithReply-remotely-error",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case msg: String => context.sendFailure(new
CelebornException("Oops"))
+ }
+ })
+
+ val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0,
clientMode = true)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address,
"sendWithReply-remotely-error")
+ try {
+ val f = rpcEndpointRef.ask[String]("hello")
+ val e = intercept[CelebornException] {
+ ThreadUtils.awaitResult(f, 5.seconds)
+ }
+ assert("Oops" === e.getCause.getMessage)
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ /**
+ * Setup an [[RpcEndpoint]] to collect all network events.
+ *
+ * @return the [[RpcEndpointRef]] and a `ConcurrentLinkedQueue` that
contains network events.
+ */
+ private def setupNetworkEndpoint(
+ _env: RpcEnv,
+ name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = {
+ val events = new ConcurrentLinkedQueue[(Any, Any)]
+ val ref = _env.setupEndpoint(
+ "network-events-non-client",
+ new ThreadSafeRpcEndpoint {
+ override val rpcEnv = _env
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case "hello" =>
+ case m => events.add("receive" -> m)
+ }
+
+ override def onConnected(remoteAddress: RpcAddress): Unit = {
+ events.add("onConnected" -> remoteAddress)
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ events.add("onDisconnected" -> remoteAddress)
+ }
+
+ override def onNetworkError(cause: Throwable, remoteAddress:
RpcAddress): Unit = {
+ events.add("onNetworkError" -> remoteAddress)
+ }
+
+ })
+ (ref, events)
+ }
+
+ test("network events in sever RpcEnv when another RpcEnv is in server mode")
{
+ val serverEnv1 = createRpcEnv(createCelebornConf(), "server1", 0,
clientMode = false)
+ val serverEnv2 = createRpcEnv(createCelebornConf(), "server2", 0,
clientMode = false)
+ val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events")
+ val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events")
+ try {
+ val serverRefInServer2 = serverEnv1.setupEndpointRef(serverRef2.address,
serverRef2.name)
+ // Send a message to set up the connection
+ serverRefInServer2.send("hello")
+
+ eventually(timeout(5.seconds), interval(5.milliseconds)) {
+ assert(events.contains(("onConnected", serverEnv2.address)))
+ }
+
+ serverEnv2.shutdown()
+ serverEnv2.awaitTermination()
+
+ eventually(timeout(5.seconds), interval(5.milliseconds)) {
+ assert(events.contains(("onConnected", serverEnv2.address)))
+ assert(events.contains(("onDisconnected", serverEnv2.address)))
+ }
+ } finally {
+ serverEnv1.shutdown()
+ serverEnv2.shutdown()
+ serverEnv1.awaitTermination()
+ serverEnv2.awaitTermination()
+ }
+ }
+
+ test("network events in sever RpcEnv when another RpcEnv is in client mode")
{
+ val serverEnv = createRpcEnv(createCelebornConf(), "server", 0, clientMode
= false)
+ val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events")
+ val clientEnv = createRpcEnv(createCelebornConf(), "client", 0, clientMode
= true)
+ try {
+ val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address,
serverRef.name)
+ // Send a message to set up the connection
+ serverRefInClient.send("hello")
+
+ eventually(timeout(5.seconds), interval(5.milliseconds)) {
+ // We don't know the exact client address but at least we can verify
the message type
+ assert(events.asScala.map(_._1).exists(_ == "onConnected"))
+ }
+
+ clientEnv.shutdown()
+ clientEnv.awaitTermination()
+
+ eventually(timeout(5.seconds), interval(5.milliseconds)) {
+ // We don't know the exact client address but at least we can verify
the message type
+ assert(events.asScala.map(_._1).exists(_ == "onConnected"))
+ assert(events.asScala.map(_._1).exists(_ == "onDisconnected"))
+ }
+ } finally {
+ clientEnv.shutdown()
+ serverEnv.shutdown()
+ clientEnv.awaitTermination()
+ serverEnv.awaitTermination()
+ }
+ }
+
+ test("network events in client RpcEnv when another RpcEnv is in server
mode") {
+ val clientEnv = createRpcEnv(createCelebornConf(), "client", 0, clientMode
= true)
+ val serverEnv = createRpcEnv(createCelebornConf(), "server", 0, clientMode
= false)
+ val (_, events) = setupNetworkEndpoint(clientEnv, "network-events")
+ val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events")
+ try {
+ val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address,
serverRef.name)
+ // Send a message to set up the connection
+ serverRefInClient.send("hello")
+
+ eventually(timeout(5.seconds), interval(5.milliseconds)) {
+ assert(events.contains(("onConnected", serverEnv.address)))
+ }
+
+ serverEnv.shutdown()
+ serverEnv.awaitTermination()
+
+ eventually(timeout(5.seconds), interval(5.milliseconds)) {
+ assert(events.contains(("onConnected", serverEnv.address)))
+ assert(events.contains(("onDisconnected", serverEnv.address)))
+ }
+ } finally {
+ clientEnv.shutdown()
+ serverEnv.shutdown()
+ clientEnv.awaitTermination()
+ serverEnv.awaitTermination()
+ }
+ }
+
+ test("sendWithReply: unserializable error") {
+ env.setupEndpoint(
+ "sendWithReply-unserializable-error",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case msg: String => context.sendFailure(new UnserializableException)
+ }
+ })
+
+ val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0,
clientMode = true)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef =
+ anotherEnv.setupEndpointRef(env.address,
"sendWithReply-unserializable-error")
+ try {
+ val f = rpcEndpointRef.ask[String]("hello")
+ val e = intercept[CelebornException] {
+ ThreadUtils.awaitResult(f, 1.second)
+ }
+ assert(e.getCause.isInstanceOf[NotSerializableException])
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("port conflict") {
+ val anotherEnv = createRpcEnv(createCelebornConf(), "remote",
env.address.port)
+ try {
+ assert(anotherEnv.address.port != env.address.port)
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ private def testSend(conf: CelebornConf): Unit = {
+ val localEnv = createRpcEnv(conf, "authentication-local", 0)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode
= true)
+
+ try {
+ @volatile var message: String = null
+ localEnv.setupEndpoint(
+ "send-authentication",
+ new RpcEndpoint {
+ override val rpcEnv = localEnv
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case msg: String => message = msg
+ }
+ })
+ val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address,
"send-authentication")
+ rpcEndpointRef.send("hello")
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
+ assert("hello" === message)
+ }
+ } finally {
+ localEnv.shutdown()
+ localEnv.awaitTermination()
+ remoteEnv.shutdown()
+ remoteEnv.awaitTermination()
+ }
+ }
+
+ private def testAsk(conf: CelebornConf): Unit = {
+ val localEnv = createRpcEnv(conf, "authentication-local", 0)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode
= true)
+
+ try {
+ localEnv.setupEndpoint(
+ "ask-authentication",
+ new RpcEndpoint {
+ override val rpcEnv = localEnv
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case msg: String =>
+ context.reply(msg)
+ }
+ })
+ val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address,
"ask-authentication")
+ val reply = rpcEndpointRef.askSync[String]("hello")
+ assert("hello" === reply)
+ } finally {
+ localEnv.shutdown()
+ localEnv.awaitTermination()
+ remoteEnv.shutdown()
+ remoteEnv.awaitTermination()
+ }
+ }
+
+ test("construct RpcTimeout with conf property") {
+ val conf = new CelebornConf()
+
+ val testProp = "celeborn.ask.test.timeout"
+ val testDurationSeconds = 30
+ val secondaryProp = "celeborn.ask.secondary.timeout"
+
+ conf.set(testProp, s"${testDurationSeconds}s")
+ conf.set(secondaryProp, "100s")
+
+ // Construct RpcTimeout with a single property
+ val rt1 = RpcTimeout(conf, testProp)
+ assert(testDurationSeconds === rt1.duration.toSeconds)
+
+ // Construct RpcTimeout with prioritized list of properties
+ val rt2 = RpcTimeout(conf, Seq("celeborn.ask.invalid.timeout", testProp,
secondaryProp), "1s")
+ assert(testDurationSeconds === rt2.duration.toSeconds)
+
+ // Construct RpcTimeout with default value,
+ val defaultProp = "celeborn.ask.default.timeout"
+ val defaultDurationSeconds = 1
+ val rt3 = RpcTimeout(conf, Seq(defaultProp),
defaultDurationSeconds.toString + "s")
+ assert(defaultDurationSeconds === rt3.duration.toSeconds)
+ assert(rt3.timeoutProp.contains(defaultProp))
+
+ // Try to construct RpcTimeout with an unconfigured property
+ intercept[NoSuchElementException] {
+ RpcTimeout(conf, "celeborn.ask.invalid.timeout")
+ }
+ }
+
+ test("ask a message timeout on Future using RpcTimeout") {
+ case class NeverReply(msg: String)
+
+ val rpcEndpointRef = env.setupEndpoint(
+ "ask-future",
+ new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case msg: String => context.reply(msg)
+ case _: NeverReply =>
+ }
+ })
+
+ val longTimeout = new RpcTimeout(1.second, "celeborn.rpc.long.timeout")
+ val shortTimeout = new RpcTimeout(10.milliseconds,
"celeborn.rpc.short.timeout")
+
+ // Ask with immediate response, should complete successfully
+ val fut1 = rpcEndpointRef.ask[String]("hello", longTimeout)
+ val reply1 = longTimeout.awaitResult(fut1)
+ assert("hello" === reply1)
+
+ // Ask with a delayed response and wait for response immediately that
should timeout
+ val fut2 = rpcEndpointRef.ask[String](NeverReply("doh"), shortTimeout)
+ val reply2 =
+ intercept[RpcTimeoutException] {
+ shortTimeout.awaitResult(fut2)
+ }.getMessage
+
+ // RpcTimeout.awaitResult should have added the property to the
TimeoutException message
+ assert(reply2.contains(shortTimeout.timeoutProp))
+
+ // Ask with delayed response and allow the Future to timeout before
ThreadUtils.awaitResult
+ val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout)
+
+ // scalastyle:off awaitresult
+ // Allow future to complete with failure using plain Await.result, this
will return
+ // once the future is complete to verify addMessageIfTimeout was invoked
+ val reply3 =
+ intercept[RpcTimeoutException] {
+ Await.result(fut3, 2.seconds)
+ }.getMessage
+ // scalastyle:on awaitresult
+
+ // When the future timed out, the recover callback should have used
+ // RpcTimeout.addMessageIfTimeout to add the property to the
TimeoutException message
+ assert(reply3.contains(shortTimeout.timeoutProp))
+
+ // Use RpcTimeout.awaitResult to process Future, since it has already
failed with
+ // RpcTimeoutException, the same RpcTimeoutException should be thrown
+ val reply4 =
+ intercept[RpcTimeoutException] {
+ shortTimeout.awaitResult(fut3)
+ }.getMessage
+
+ // Ensure description is not in message twice after addMessageIfTimeout
and awaitResult
+ assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
+ }
+
+ test("RpcEnv.shutdown should not fire onDisconnected events") {
+ env.setupEndpoint(
+ "test_ep_11212023",
+ new RpcEndpoint {
+ override val rpcEnv: RpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ case m => context.reply(m)
+ }
+ })
+
+ val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0)
+ val endpoint = mock(classOf[RpcEndpoint])
+ anotherEnv.setupEndpoint("test_ep_11212023", endpoint)
+
+ val ref = anotherEnv.setupEndpointRef(env.address, "test_ep_11212023")
+ // Make sure the connect is set up
+ assert(ref.askSync[String]("hello") === "hello")
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+
+ env.stop(ref)
+
+ verify(endpoint).onStop()
+ verify(endpoint, never()).onDisconnected(any())
+ verify(endpoint, never()).onNetworkError(any(), any())
+ }
+}
+
+case class Register(ref: RpcEndpointRef)
+
+class UnserializableClass
+
+class UnserializableException extends Exception {
+ private val unserializableField = new UnserializableClass
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala
new file mode 100644
index 000000000..9d343b037
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalactic.TripleEquals
+import org.scalatest.Assertions._
+
+class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals {
+
+ override val rpcEnv: RpcEnv = null
+
+ @volatile private var receiveMessages = ArrayBuffer[Any]()
+
+ @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]()
+
+ @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]()
+
+ @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]()
+
+ @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable,
RpcAddress)]()
+
+ @volatile private var started = false
+
+ @volatile private var stopped = false
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case message: Any => receiveMessages += message
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any,
Unit] = {
+ case message: Any => receiveAndReplyMessages += message
+ }
+
+ override def onConnected(remoteAddress: RpcAddress): Unit = {
+ onConnectedMessages += remoteAddress
+ }
+
+ /**
+ * Invoked when some network error happens in the connection between the
current node and
+ * `remoteAddress`.
+ */
+ override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress):
Unit = {
+ onNetworkErrorMessages += cause -> remoteAddress
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ onDisconnectedMessages += remoteAddress
+ }
+
+ def numReceiveMessages: Int = receiveMessages.size
+
+ override def onStart(): Unit = {
+ started = true
+ }
+
+ override def onStop(): Unit = {
+ stopped = true
+ }
+
+ def verifyStarted(): Unit = {
+ assert(started, "RpcEndpoint is not started")
+ }
+
+ def verifyStopped(): Unit = {
+ assert(stopped, "RpcEndpoint is not stopped")
+ }
+
+ def verifyReceiveMessages(expected: Seq[Any]): Unit = {
+ assert(receiveMessages === expected)
+ }
+
+ def verifySingleReceiveMessage(message: Any): Unit = {
+ verifyReceiveMessages(List(message))
+ }
+
+ def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = {
+ assert(receiveAndReplyMessages === expected)
+ }
+
+ def verifySingleReceiveAndReplyMessage(message: Any): Unit = {
+ verifyReceiveAndReplyMessages(List(message))
+ }
+
+ def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = {
+ verifyOnConnectedMessages(List(remoteAddress))
+ }
+
+ def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = {
+ assert(onConnectedMessages === expected)
+ }
+
+ def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = {
+ verifyOnDisconnectedMessages(List(remoteAddress))
+ }
+
+ def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = {
+ assert(onDisconnectedMessages === expected)
+ }
+
+ def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress:
RpcAddress): Unit = {
+ verifyOnNetworkErrorMessages(List(cause -> remoteAddress))
+ }
+
+ def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]):
Unit = {
+ assert(onNetworkErrorMessages === expected)
+ }
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
new file mode 100644
index 000000000..a8bc826dd
--- /dev/null
+++
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.mockito.Mockito._
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.rpc.{RpcAddress, TestRpcEndpoint}
+
+class InboxSuite extends CelebornFunSuite {
+
+ test("post") {
+ val endpoint = new TestRpcEndpoint
+ val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val inbox = new Inbox(rpcEnvRef, endpoint)
+ val message = OneWayMessage(null, "hi")
+ inbox.post(message)
+ inbox.process(dispatcher)
+ assert(inbox.isEmpty)
+
+ endpoint.verifySingleReceiveMessage("hi")
+
+ inbox.stop()
+ inbox.process(dispatcher)
+ assert(inbox.isEmpty)
+ endpoint.verifyStarted()
+ endpoint.verifyStopped()
+ }
+
+ test("post: with reply") {
+ val endpoint = new TestRpcEndpoint
+ val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val inbox = new Inbox(rpcEnvRef, endpoint)
+ val message = RpcMessage(null, "hi", null)
+ inbox.post(message)
+ inbox.process(dispatcher)
+ assert(inbox.isEmpty)
+
+ endpoint.verifySingleReceiveAndReplyMessage("hi")
+ }
+
+ test("post: multiple threads") {
+ val endpoint = new TestRpcEndpoint
+ val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val numDroppedMessages = new AtomicInteger(0)
+ val inbox = new Inbox(rpcEnvRef, endpoint) {
+ override def onDrop(message: InboxMessage): Unit = {
+ numDroppedMessages.incrementAndGet()
+ }
+ }
+
+ val exitLatch = new CountDownLatch(10)
+
+ for (_ <- 0 until 10) {
+ new Thread {
+ override def run(): Unit = {
+ for (_ <- 0 until 100) {
+ val message = OneWayMessage(null, "hi")
+ inbox.post(message)
+ }
+ exitLatch.countDown()
+ }
+ }.start()
+ }
+ // Try to process some messages
+ inbox.process(dispatcher)
+ inbox.stop()
+ // After `stop` is called, further messages will be dropped. However,
while `stop` is called,
+ // some messages may be post to Inbox, so process them here.
+ inbox.process(dispatcher)
+ assert(inbox.isEmpty)
+
+ exitLatch.await(30, TimeUnit.SECONDS)
+
+ assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get)
+ endpoint.verifyStarted()
+ endpoint.verifyStopped()
+ }
+
+ test("post: Associated") {
+ val endpoint = new TestRpcEndpoint
+ val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+ val remoteAddress = RpcAddress("localhost", 11111)
+
+ val inbox = new Inbox(rpcEnvRef, endpoint)
+ inbox.post(RemoteProcessConnected(remoteAddress))
+ inbox.process(dispatcher)
+
+ endpoint.verifySingleOnConnectedMessage(remoteAddress)
+ }
+
+ test("post: Disassociated") {
+ val endpoint = new TestRpcEndpoint
+ val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val remoteAddress = RpcAddress("localhost", 11111)
+
+ val inbox = new Inbox(rpcEnvRef, endpoint)
+ inbox.post(RemoteProcessDisconnected(remoteAddress))
+ inbox.process(dispatcher)
+
+ endpoint.verifySingleOnDisconnectedMessage(remoteAddress)
+ }
+
+ test("post: AssociationError") {
+ val endpoint = new TestRpcEndpoint
+ val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val remoteAddress = RpcAddress("localhost", 11111)
+ val cause = new RuntimeException("Oops")
+
+ val inbox = new Inbox(rpcEnvRef, endpoint)
+ inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
+ inbox.process(dispatcher)
+
+ endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress)
+ }
+
+ test("should reduce the number of active threads when fatal error happens") {
+ val endpoint = mock(classOf[TestRpcEndpoint])
+ when(endpoint.receive).thenThrow(new OutOfMemoryError())
+ val rpcEnvRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+ val inbox = new Inbox(rpcEnvRef, endpoint)
+ inbox.post(OneWayMessage(null, "hi"))
+ intercept[OutOfMemoryError] {
+ inbox.process(dispatcher)
+ }
+ assert(inbox.getNumActiveThreads === 0)
+ }
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala
new file mode 100644
index 000000000..1c3fe6ef4
--- /dev/null
+++
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.rpc.RpcEndpointAddress
+
+class NettyRpcAddressSuite extends CelebornFunSuite {
+
+ test("toString") {
+ val addr = new RpcEndpointAddress("localhost", 12345, "test")
+ assert(addr.toString === "celeborn://test@localhost:12345")
+ }
+
+ test("toString for client mode") {
+ val addr = RpcEndpointAddress(null, "test")
+ assert(addr.toString === "celeborn-client://test")
+ }
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala
new file mode 100644
index 000000000..8afbf5980
--- /dev/null
+++
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import java.util.concurrent.ExecutionException
+
+import scala.concurrent.duration._
+
+import org.mockito.Mockito.mock
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
+
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.network.client.TransportClient
+import org.apache.celeborn.common.rpc._
+import org.apache.celeborn.common.util.ThreadUtils
+
+class NettyRpcEnvSuite extends RpcEnvSuite with TimeLimits {
+
+ implicit private val signaler: Signaler = ThreadSignaler
+
+ override def createRpcEnv(
+ conf: CelebornConf,
+ name: String,
+ port: Int,
+ clientMode: Boolean = false): RpcEnv = {
+ val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port, 0)
+ new NettyRpcEnvFactory().create(config)
+ }
+
+ test("non-existent endpoint") {
+ val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString
+ val e = intercept[CelebornException] {
+ env.setupEndpointRef(env.address, "nonexist-endpoint")
+ }
+ assert(e.getCause.isInstanceOf[RpcEndpointNotFoundException])
+ assert(e.getCause.getMessage.contains(uri))
+ }
+
+ test("advertise address different from bind address") {
+ val celebornConf = createCelebornConf()
+ val config = RpcEnvConfig(celebornConf, "test", "localhost",
"example.com", 0, 0)
+ val env = new NettyRpcEnvFactory().create(config)
+ try {
+ assert(env.address.hostPort.startsWith("example.com:"))
+ } finally {
+ env.shutdown()
+ }
+ }
+
+ test("RequestMessage serialization") {
+ def assertRequestMessageEquals(expected: RequestMessage, actual:
RequestMessage): Unit = {
+ assert(expected.senderAddress === actual.senderAddress)
+ assert(expected.receiver === actual.receiver)
+ assert(expected.content === actual.content)
+ }
+
+ val nettyEnv = env.asInstanceOf[NettyRpcEnv]
+ val client = mock(classOf[TransportClient])
+ val senderAddress = RpcAddress("localhost", 12345)
+ val receiverAddress = RpcEndpointAddress("localhost", 54321, "test")
+ val receiver = new NettyRpcEndpointRef(nettyEnv.celebornConf,
receiverAddress, nettyEnv)
+
+ val msg = new RequestMessage(senderAddress, receiver, "foo")
+ assertRequestMessageEquals(
+ msg,
+ RequestMessage(nettyEnv, client, msg.serialize(nettyEnv)))
+
+ val msg2 = new RequestMessage(null, receiver, "foo")
+ assertRequestMessageEquals(
+ msg2,
+ RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv)))
+
+ val msg3 = new RequestMessage(senderAddress, receiver, null)
+ assertRequestMessageEquals(
+ msg3,
+ RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv)))
+ }
+
+ test("StackOverflowError should be sent back and Dispatcher should survive")
{
+ val numUsableCores = 2
+ val conf = createCelebornConf()
+ val config = RpcEnvConfig(
+ conf,
+ "test",
+ "localhost",
+ "localhost",
+ 0,
+ numUsableCores)
+ val anotherEnv = new NettyRpcEnvFactory().create(config)
+ anotherEnv.setupEndpoint(
+ "StackOverflowError",
+ new RpcEndpoint {
+ override val rpcEnv = anotherEnv
+
+ override def receiveAndReply(context: RpcCallContext):
PartialFunction[Any, Unit] = {
+ // scalastyle:off throwerror
+ case msg: String => throw new StackOverflowError
+ // scalastyle:on throwerror
+ case num: Int => context.reply(num)
+ }
+ })
+
+ val rpcEndpointRef = env.setupEndpointRef(anotherEnv.address,
"StackOverflowError")
+ try {
+ // Send `numUsableCores` messages to trigger `numUsableCores`
`StackOverflowError`s
+ for (_ <- 0 until numUsableCores) {
+ val e = intercept[CelebornException] {
+ rpcEndpointRef.askSync[String]("hello")
+ }
+ // The root cause `e.getCause.getCause` because it is boxed by Scala
Promise.
+ assert(e.getCause.isInstanceOf[ExecutionException])
+ assert(e.getCause.getCause.isInstanceOf[StackOverflowError])
+ }
+ failAfter(10.seconds) {
+ assert(rpcEndpointRef.askSync[Int](100) === 100)
+ }
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+}
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala
new file mode 100644
index 000000000..3f6ebb0e6
--- /dev/null
+++
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+import io.netty.channel.Channel
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.network.client.{TransportClient,
TransportResponseHandler}
+import org.apache.celeborn.common.rpc.RpcAddress
+
+class NettyRpcHandlerSuite extends CelebornFunSuite {
+
+ val env = mock(classOf[NettyRpcEnv])
+ when(env.deserialize(any(classOf[TransportClient]),
any(classOf[ByteBuffer]))(any()))
+ .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null))
+
+ test("receive") {
+ val dispatcher = mock(classOf[Dispatcher])
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+
+ val channel = mock(classOf[Channel])
+ val client = new TransportClient(channel,
mock(classOf[TransportResponseHandler]))
+ when(channel.remoteAddress()).thenReturn(new
InetSocketAddress("localhost", 40000))
+ nettyRpcHandler.channelActive(client)
+
+ verify(dispatcher,
times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
+ }
+
+ test("connectionTerminated") {
+ val dispatcher = mock(classOf[Dispatcher])
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+
+ val channel = mock(classOf[Channel])
+ val client = new TransportClient(channel,
mock(classOf[TransportResponseHandler]))
+ when(channel.remoteAddress()).thenReturn(new
InetSocketAddress("localhost", 40000))
+ nettyRpcHandler.channelActive(client)
+
+ when(channel.remoteAddress()).thenReturn(new
InetSocketAddress("localhost", 40000))
+ nettyRpcHandler.channelInactive(client)
+
+ verify(dispatcher,
times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
+ verify(dispatcher, times(1)).postToAll(
+ RemoteProcessDisconnected(RpcAddress("localhost", 40000)))
+ }
+
+}