This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d72c9595 [CELEBORN-1365] Ensure that a client cannot update the metadata belonging to a different application
0d72c9595 is described below

commit 0d72c95958d32e138d136b3dc8e62046a6c08cad
Author: Chandni Singh <[email protected]>
AuthorDate: Mon Apr 8 10:35:13 2024 +0800

    [CELEBORN-1365] Ensure that a client cannot update the metadata belonging to a different application
    
    ### What changes were proposed in this pull request?
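    This ensures that an authenticated client cannot update the metadata belonging to another application. Master and Worker RPC handlers now call `checkAuth`, which throws an `IllegalStateException` when the caller's authenticated client ID does not match the application ID carried in the request. The check is skipped when the connection has no client ID.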
    
    ### Why are the changes needed?
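    The changes are needed for authentication support. Authenticating a client is not sufficient on its own: without this check, a client authenticated for one application could send RPCs (for example heartbeats, slot requests, shuffle unregistration or application-lost notifications) that modify another application's metadata.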
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    
    Closes #2441 from otterc/CELEBORN-1365.
    
    Authored-by: Chandni Singh <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../common/network/server/BaseMessageHandler.java       |  2 +-
 .../org/apache/celeborn/common/rpc/RpcEndpoint.scala    | 17 +++++++++++++++++
 .../apache/celeborn/common/rpc/netty/Dispatcher.scala   |  9 ++++++---
 .../celeborn/common/rpc/netty/NettyRpcCallContext.scala |  5 +++--
 .../apache/celeborn/common/rpc/netty/NettyRpcEnv.scala  |  2 +-
 .../apache/celeborn/service/deploy/master/Master.scala  |  5 +++++
 .../celeborn/service/deploy/worker/Controller.scala     |  3 +++
 7 files changed, 36 insertions(+), 7 deletions(-)

diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
index 974166745..afc869de5 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
@@ -49,7 +49,7 @@ public class BaseMessageHandler {
 
   protected void checkAuth(TransportClient client, String appId) {
     if (client.getClientId() != null && !client.getClientId().equals(appId)) {
-      throw new SecurityException(
+      throw new IllegalStateException(
           String.format(
               "Client for %s not authorized for application %s.", 
client.getClientId(), appId));
     }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala
index cf69c8242..8d12c227a 100644
--- a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala
@@ -18,6 +18,8 @@
 package org.apache.celeborn.common.rpc
 
 import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.network.client.TransportClient
+import org.apache.celeborn.common.rpc.netty.RemoteNettyRpcCallContext
 
 /**
  * A factory class to create the [[RpcEnv]]. It must have an empty constructor 
so that it can be
@@ -134,6 +136,21 @@ trait RpcEndpoint {
       rpcEnv.stop(_self)
     }
   }
+
+  def checkAuth(context: RpcCallContext, appId: String): Unit = {
+    context match {
+      case remoteContext: RemoteNettyRpcCallContext =>
+        checkAuth(remoteContext.transportClient, appId)
+      case _ =>
+      // Do nothing if the context is not RemoteNettyRpcCallContext
+    }
+  }
+
+  private def checkAuth(client: TransportClient, appId: String): Unit = {
+    if (client.getClientId != null && client.getClientId != appId)
+      throw new IllegalStateException(
+        s"Client for ${client.getClientId} not authorized for application 
$appId.")
+  }
 }
 
 /**
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
index b137ae2f5..391b64186 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 
 import org.apache.celeborn.common.exception.CelebornException
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.network.client.RpcResponseCallback
+import org.apache.celeborn.common.network.client.{RpcResponseCallback, 
TransportClient}
 import org.apache.celeborn.common.rpc._
 import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
 
@@ -120,9 +120,12 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv) 
extends Logging {
   }
 
   /** Posts a message sent by a remote endpoint. */
-  def postRemoteMessage(message: RequestMessage, callback: 
RpcResponseCallback): Unit = {
+  def postRemoteMessage(
+      message: RequestMessage,
+      callback: RpcResponseCallback,
+      client: TransportClient): Unit = {
     val rpcCallContext =
-      new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
+      new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress, 
client)
     val rpcMessage = RpcMessage(message.senderAddress, message.content, 
rpcCallContext)
     postMessage(message.receiver.name, rpcMessage, (e) => 
callback.onFailure(e))
   }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcCallContext.scala
 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcCallContext.scala
index 82a6f8547..3a51e7c74 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcCallContext.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcCallContext.scala
@@ -20,7 +20,7 @@ package org.apache.celeborn.common.rpc.netty
 import scala.concurrent.Promise
 
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.network.client.RpcResponseCallback
+import org.apache.celeborn.common.network.client.{RpcResponseCallback, 
TransportClient}
 import org.apache.celeborn.common.rpc.{RpcAddress, RpcCallContext}
 
 abstract private[celeborn] class NettyRpcCallContext(override val 
senderAddress: RpcAddress)
@@ -57,7 +57,8 @@ private[celeborn] class LocalNettyRpcCallContext(
 private[celeborn] class RemoteNettyRpcCallContext(
     val nettyEnv: NettyRpcEnv,
     val callback: RpcResponseCallback,
-    senderAddress: RpcAddress)
+    senderAddress: RpcAddress,
+    val transportClient: TransportClient)
   extends NettyRpcCallContext(senderAddress) {
 
   override protected def send(message: Any): Unit = {
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
index 824935051..b0f470efe 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
@@ -585,7 +585,7 @@ private[celeborn] class NettyRpcHandler(
     try {
       val message = requestMessage.body().nioByteBuffer()
       val messageToDispatch = internalReceive(client, message)
-      dispatcher.postRemoteMessage(messageToDispatch, callback)
+      dispatcher.postRemoteMessage(messageToDispatch, callback, client)
     } catch {
       case e: Exception =>
         val rpcReq = requestMessage.asInstanceOf[RpcRequest]
diff --git 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index 05ae53cc3..736cc7e45 100644
--- 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++ 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -397,6 +397,7 @@ private[celeborn] class Master(
           requestId,
           shouldResponse) =>
       logDebug(s"Received heartbeat from app $appId")
+      checkAuth(context, appId)
       executeWithLeaderChecker(
         context,
         handleHeartbeatFromApplication(
@@ -446,6 +447,7 @@ private[celeborn] class Master(
 
     case requestSlots @ RequestSlots(applicationId, _, _, _, _, _, _, _, _, _, 
_) =>
       logTrace(s"Received RequestSlots request $requestSlots.")
+      checkAuth(context, applicationId)
       executeWithLeaderChecker(context, handleRequestSlots(context, 
requestSlots))
 
     case pb: PbUnregisterShuffle =>
@@ -453,6 +455,7 @@ private[celeborn] class Master(
       val shuffleId = pb.getShuffleId
       val requestId = pb.getRequestId
       logDebug(s"Received UnregisterShuffle request $requestId, 
$applicationId, $shuffleId")
+      checkAuth(context, applicationId)
       executeWithLeaderChecker(
         context,
         handleUnregisterShuffle(context, applicationId, shuffleId, requestId))
@@ -460,6 +463,7 @@ private[celeborn] class Master(
     case ApplicationLost(appId, requestId) =>
       logDebug(
         s"Received ApplicationLost request $requestId, $appId from 
${context.senderAddress}.")
+      checkAuth(context, appId)
       executeWithLeaderChecker(context, handleApplicationLost(context, appId, 
requestId))
 
     case HeartbeatFromWorker(
@@ -542,6 +546,7 @@ private[celeborn] class Master(
           context))
 
     case pb: PbApplicationMetaRequest =>
+      // This request is from a worker
       executeWithLeaderChecker(context, 
handleRequestForApplicationMeta(context, pb))
   }
 
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index ed3df297d..46e440df2 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -89,6 +89,7 @@ private[deploy] class Controller(
           userIdentifier,
           pushDataTimeout,
           partitionSplitEnabled) =>
+      checkAuth(context, applicationId)
       val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
       workerSource.sample(WorkerSource.RESERVE_SLOTS_TIME, shuffleKey) {
         logDebug(s"Received ReserveSlots request, $shuffleKey, " +
@@ -118,6 +119,7 @@ private[deploy] class Controller(
           mapAttempts,
           epoch,
           mockFailure) =>
+      checkAuth(context, applicationId)
       val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
       logDebug(s"Received CommitFiles request, $shuffleKey, primary files" +
         s" ${primaryIds.asScala.mkString(",")}; replica files 
${replicaIds.asScala.mkString(",")}.")
@@ -135,6 +137,7 @@ private[deploy] class Controller(
         s"$commitFilesTimeMs ms.")
 
     case DestroyWorkerSlots(shuffleKey, primaryLocations, replicaLocations, 
mockFailure) =>
+      checkAuth(context, Utils.splitShuffleKey(shuffleKey)._1)
       handleDestroy(context, shuffleKey, primaryLocations, replicaLocations, 
mockFailure)
   }
 

Reply via email to