This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 0d72c9595 [CELEBORN-1365] Ensure that a client cannot update the metadata belonging to a different application
0d72c9595 is described below
commit 0d72c95958d32e138d136b3dc8e62046a6c08cad
Author: Chandni Singh <[email protected]>
AuthorDate: Mon Apr 8 10:35:13 2024 +0800
[CELEBORN-1365] Ensure that a client cannot update the metadata belonging to a different application
### What changes were proposed in this pull request?
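This ensures that an authenticated client cannot update the metadata
belonging to another application. `RpcEndpoint` gains a
`checkAuth(context, appId)` helper. For remote calls, it throws an
`IllegalStateException` when the transport client's ID is set and differs
from the application ID carried by the request. The Master applies this
check to heartbeat, RequestSlots, UnregisterShuffle and ApplicationLost
requests. The Worker's Controller applies it to ReserveSlots, CommitFiles and
DestroyWorkerSlots requests. The existing check in `BaseMessageHandler` now
throws `IllegalStateException` as well, instead of `SecurityException`.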
### Why are the changes needed?
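Without this check, a client authenticated for one application could modify the metadata of another application; enforcing it is required for authentication support.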
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Closes #2441 from otterc/CELEBORN-1365.
Authored-by: Chandni Singh <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../common/network/server/BaseMessageHandler.java | 2 +-
.../org/apache/celeborn/common/rpc/RpcEndpoint.scala | 17 +++++++++++++++++
.../apache/celeborn/common/rpc/netty/Dispatcher.scala | 9 ++++++---
.../celeborn/common/rpc/netty/NettyRpcCallContext.scala | 5 +++--
.../apache/celeborn/common/rpc/netty/NettyRpcEnv.scala | 2 +-
.../apache/celeborn/service/deploy/master/Master.scala | 5 +++++
.../celeborn/service/deploy/worker/Controller.scala | 3 +++
7 files changed, 36 insertions(+), 7 deletions(-)
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
index 974166745..afc869de5 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
@@ -49,7 +49,7 @@ public class BaseMessageHandler {
protected void checkAuth(TransportClient client, String appId) {
if (client.getClientId() != null && !client.getClientId().equals(appId)) {
- throw new SecurityException(
+ throw new IllegalStateException(
String.format(
"Client for %s not authorized for application %s.",
client.getClientId(), appId));
}
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala
index cf69c8242..8d12c227a 100644
--- a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala
@@ -18,6 +18,8 @@
package org.apache.celeborn.common.rpc
import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.network.client.TransportClient
+import org.apache.celeborn.common.rpc.netty.RemoteNettyRpcCallContext
/**
* A factory class to create the [[RpcEnv]]. It must have an empty constructor
so that it can be
@@ -134,6 +136,21 @@ trait RpcEndpoint {
rpcEnv.stop(_self)
}
}
+
+ def checkAuth(context: RpcCallContext, appId: String): Unit = {
+ context match {
+ case remoteContext: RemoteNettyRpcCallContext =>
+ checkAuth(remoteContext.transportClient, appId)
+ case _ =>
+ // Do nothing if the context is not RemoteNettyRpcCallContext
+ }
+ }
+
+ private def checkAuth(client: TransportClient, appId: String): Unit = {
+ if (client.getClientId != null && client.getClientId != appId)
+ throw new IllegalStateException(
+ s"Client for ${client.getClientId} not authorized for application
$appId.")
+ }
}
/**
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
index b137ae2f5..391b64186 100644
---
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.celeborn.common.exception.CelebornException
import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.network.client.RpcResponseCallback
+import org.apache.celeborn.common.network.client.{RpcResponseCallback,
TransportClient}
import org.apache.celeborn.common.rpc._
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
@@ -120,9 +120,12 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv)
extends Logging {
}
/** Posts a message sent by a remote endpoint. */
- def postRemoteMessage(message: RequestMessage, callback:
RpcResponseCallback): Unit = {
+ def postRemoteMessage(
+ message: RequestMessage,
+ callback: RpcResponseCallback,
+ client: TransportClient): Unit = {
val rpcCallContext =
- new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
+ new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress,
client)
val rpcMessage = RpcMessage(message.senderAddress, message.content,
rpcCallContext)
postMessage(message.receiver.name, rpcMessage, (e) =>
callback.onFailure(e))
}
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcCallContext.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcCallContext.scala
index 82a6f8547..3a51e7c74 100644
---
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcCallContext.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcCallContext.scala
@@ -20,7 +20,7 @@ package org.apache.celeborn.common.rpc.netty
import scala.concurrent.Promise
import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.network.client.RpcResponseCallback
+import org.apache.celeborn.common.network.client.{RpcResponseCallback,
TransportClient}
import org.apache.celeborn.common.rpc.{RpcAddress, RpcCallContext}
abstract private[celeborn] class NettyRpcCallContext(override val
senderAddress: RpcAddress)
@@ -57,7 +57,8 @@ private[celeborn] class LocalNettyRpcCallContext(
private[celeborn] class RemoteNettyRpcCallContext(
val nettyEnv: NettyRpcEnv,
val callback: RpcResponseCallback,
- senderAddress: RpcAddress)
+ senderAddress: RpcAddress,
+ val transportClient: TransportClient)
extends NettyRpcCallContext(senderAddress) {
override protected def send(message: Any): Unit = {
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
index 824935051..b0f470efe 100644
---
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
@@ -585,7 +585,7 @@ private[celeborn] class NettyRpcHandler(
try {
val message = requestMessage.body().nioByteBuffer()
val messageToDispatch = internalReceive(client, message)
- dispatcher.postRemoteMessage(messageToDispatch, callback)
+ dispatcher.postRemoteMessage(messageToDispatch, callback, client)
} catch {
case e: Exception =>
val rpcReq = requestMessage.asInstanceOf[RpcRequest]
diff --git
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index 05ae53cc3..736cc7e45 100644
---
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -397,6 +397,7 @@ private[celeborn] class Master(
requestId,
shouldResponse) =>
logDebug(s"Received heartbeat from app $appId")
+ checkAuth(context, appId)
executeWithLeaderChecker(
context,
handleHeartbeatFromApplication(
@@ -446,6 +447,7 @@ private[celeborn] class Master(
case requestSlots @ RequestSlots(applicationId, _, _, _, _, _, _, _, _, _,
_) =>
logTrace(s"Received RequestSlots request $requestSlots.")
+ checkAuth(context, applicationId)
executeWithLeaderChecker(context, handleRequestSlots(context,
requestSlots))
case pb: PbUnregisterShuffle =>
@@ -453,6 +455,7 @@ private[celeborn] class Master(
val shuffleId = pb.getShuffleId
val requestId = pb.getRequestId
logDebug(s"Received UnregisterShuffle request $requestId,
$applicationId, $shuffleId")
+ checkAuth(context, applicationId)
executeWithLeaderChecker(
context,
handleUnregisterShuffle(context, applicationId, shuffleId, requestId))
@@ -460,6 +463,7 @@ private[celeborn] class Master(
case ApplicationLost(appId, requestId) =>
logDebug(
s"Received ApplicationLost request $requestId, $appId from
${context.senderAddress}.")
+ checkAuth(context, appId)
executeWithLeaderChecker(context, handleApplicationLost(context, appId,
requestId))
case HeartbeatFromWorker(
@@ -542,6 +546,7 @@ private[celeborn] class Master(
context))
case pb: PbApplicationMetaRequest =>
+ // This request is from a worker
executeWithLeaderChecker(context,
handleRequestForApplicationMeta(context, pb))
}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index ed3df297d..46e440df2 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -89,6 +89,7 @@ private[deploy] class Controller(
userIdentifier,
pushDataTimeout,
partitionSplitEnabled) =>
+ checkAuth(context, applicationId)
val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
workerSource.sample(WorkerSource.RESERVE_SLOTS_TIME, shuffleKey) {
logDebug(s"Received ReserveSlots request, $shuffleKey, " +
@@ -118,6 +119,7 @@ private[deploy] class Controller(
mapAttempts,
epoch,
mockFailure) =>
+ checkAuth(context, applicationId)
val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
logDebug(s"Received CommitFiles request, $shuffleKey, primary files" +
s" ${primaryIds.asScala.mkString(",")}; replica files
${replicaIds.asScala.mkString(",")}.")
@@ -135,6 +137,7 @@ private[deploy] class Controller(
s"$commitFilesTimeMs ms.")
case DestroyWorkerSlots(shuffleKey, primaryLocations, replicaLocations,
mockFailure) =>
+ checkAuth(context, Utils.splitShuffleKey(shuffleKey)._1)
handleDestroy(context, shuffleKey, primaryLocations, replicaLocations,
mockFailure)
}