This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 6153ba4c6 [CELEBORN-1360] Ensure that a client cannot push or fetch data belonging to a different application
6153ba4c6 is described below

commit 6153ba4c627de4bdeedcb4f1e4ae45f18ad2965b
Author: Chandni Singh <[email protected]>
AuthorDate: Wed Apr 3 09:56:40 2024 +0800

    [CELEBORN-1360] Ensure that a client cannot push or fetch data belonging to a different application
    
    ### What changes were proposed in this pull request?
    This ensures that an authenticated client cannot push or fetch data that
    belongs to another application.
    
    ### Why are the changes needed?
    The changes are needed for authentication support.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    
    Closes #2431 from otterc/CELEBORN-1360.
    
    Authored-by: Chandni Singh <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../apache/celeborn/common/network/client/TransportClient.java    | 3 ++-
 .../apache/celeborn/common/network/sasl/CelebornSaslServer.java   | 7 ++++++-
 .../org/apache/celeborn/common/network/sasl/SaslRpcHandler.java   | 2 +-
 .../network/sasl/registration/RegistrationClientBootstrap.java    | 1 +
 .../common/network/sasl/registration/RegistrationRpcHandler.java  | 1 +
 .../apache/celeborn/common/network/server/BaseMessageHandler.java | 8 ++++++++
 .../apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java   | 8 ++++++--
 .../org/apache/celeborn/service/deploy/worker/FetchHandler.scala  | 3 ++-
 .../apache/celeborn/service/deploy/worker/PushDataHandler.scala   | 5 +++++
 .../apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java | 2 +-
 .../deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java    | 2 +-
 11 files changed, 34 insertions(+), 8 deletions(-)

diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index f4b62d872..fd64b6bd0 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -354,7 +354,8 @@ public class TransportClient implements Closeable {
    * <p>Trying to set a different client ID after it's been set will result in 
an exception.
    */
   public void setClientId(String id) {
-    Preconditions.checkState(clientId == null, "Client ID has already been 
set.");
+    Preconditions.checkState(
+        clientId == null || clientId.equals(id), "Client ID has already been 
set.");
     this.clientId = id;
   }
 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
index efb9af5a5..05c05744c 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
@@ -37,6 +37,8 @@ import com.google.common.base.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.celeborn.common.network.client.TransportClient;
+
 /**
  * A SASL Server for Celeborn which simply keeps track of the state of a 
single SASL session, from
  * the initial state to the "authenticated" state. (It is not a server in the 
sense of accepting
@@ -106,6 +108,7 @@ public class CelebornSaslServer {
    */
   static class DigestCallbackHandler implements CallbackHandler {
     private final SecretRegistry secretRegistry;
+    private final TransportClient client;
 
     /**
      * The use of 'volatile' is not necessary here because the 'handle' 
invocation includes both the
@@ -114,7 +117,8 @@ public class CelebornSaslServer {
      */
     private String userName = null;
 
-    DigestCallbackHandler(SecretRegistry secretRegistry) {
+    DigestCallbackHandler(TransportClient client, SecretRegistry 
secretRegistry) {
+      this.client = Preconditions.checkNotNull(client);
       this.secretRegistry = Preconditions.checkNotNull(secretRegistry);
     }
 
@@ -129,6 +133,7 @@ public class CelebornSaslServer {
             throw new SaslException("No username provided by client");
           }
           userName = decodeIdentifier(encodedName);
+          client.setClientId(userName);
         } else if (callback instanceof PasswordCallback) {
           logger.trace("SASL server callback: setting password");
           PasswordCallback pc = (PasswordCallback) callback;
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
index 127b3e772..0033372ea 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
@@ -94,7 +94,7 @@ public class SaslRpcHandler extends AbstractAuthRpcHandler {
             new CelebornSaslServer(
                 DIGEST_MD5,
                 DEFAULT_SASL_SERVER_PROPS,
-                new CelebornSaslServer.DigestCallbackHandler(secretRegistry));
+                new CelebornSaslServer.DigestCallbackHandler(client, 
secretRegistry));
       }
       byte[] response = 
saslServer.response(saslMessage.getPayload().toByteArray());
       callback.onSuccess(ByteBuffer.wrap(response));
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationClientBootstrap.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationClientBootstrap.java
index 3604e927c..a6aa2ba88 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationClientBootstrap.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationClientBootstrap.java
@@ -112,6 +112,7 @@ public class RegistrationClientBootstrap implements 
TransportClientBootstrap {
       register(client);
       LOG.info("Registration for {}", appId);
       
registrationInfo.setRegistrationState(RegistrationInfo.RegistrationState.REGISTERED);
+      client.setClientId(appId);
     } catch (IOException | CelebornException e) {
       throw new RuntimeException(e);
     } finally {
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationRpcHandler.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationRpcHandler.java
index 10803f95d..c0b25e6f9 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationRpcHandler.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationRpcHandler.java
@@ -181,6 +181,7 @@ public class RegistrationRpcHandler extends 
BaseMessageHandler {
         LOG.trace("Application registration started {}", 
registerApplicationRequest.getId());
         processRegisterApplicationRequest(registerApplicationRequest, 
callback);
         registrationState = RegistrationState.REGISTERED;
+        client.setClientId(registerApplicationRequest.getId());
         LOG.info(
             "Application registered: appId {} rpcId {}",
             registerApplicationRequest.getId(),
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
index d975dc482..974166745 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
@@ -46,4 +46,12 @@ public class BaseMessageHandler {
   public void channelInactive(TransportClient client) {}
 
   public void exceptionCaught(Throwable cause, TransportClient client) {}
+
+  protected void checkAuth(TransportClient client, String appId) {
+    if (client.getClientId() != null && !client.getClientId().equals(appId)) {
+      throw new SecurityException(
+          String.format(
+              "Client for %s not authorized for application %s.", 
client.getClientId(), appId));
+    }
+  }
 }
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
index b97da2e86..1a24a8ddc 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
@@ -24,6 +24,7 @@ import static org.mockito.Mockito.*;
 import org.junit.Test;
 
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
 import org.apache.celeborn.common.network.util.TransportConf;
 
@@ -39,11 +40,12 @@ public class CelebornSaslSuiteJ extends SaslTestBase {
             DIGEST_MD5,
             DEFAULT_SASL_CLIENT_PROPS,
             new CelebornSaslClient.ClientCallbackHandler(TEST_USER, 
TEST_SECRET));
+    TransportClient transportClient = mock(TransportClient.class);
     CelebornSaslServer server =
         new CelebornSaslServer(
             DIGEST_MD5,
             DEFAULT_SASL_SERVER_PROPS,
-            new CelebornSaslServer.DigestCallbackHandler(secretRegistry));
+            new CelebornSaslServer.DigestCallbackHandler(transportClient, 
secretRegistry));
 
     assertFalse(client.isComplete());
     assertFalse(server.isComplete());
@@ -54,6 +56,7 @@ public class CelebornSaslSuiteJ extends SaslTestBase {
       clientMessage = client.response(server.response(clientMessage));
     }
     assertTrue(server.isComplete());
+    verify(transportClient, times(1)).setClientId(TEST_USER);
 
     // Disposal should invalidate
     server.dispose();
@@ -73,7 +76,8 @@ public class CelebornSaslSuiteJ extends SaslTestBase {
         new CelebornSaslServer(
             DIGEST_MD5,
             DEFAULT_SASL_SERVER_PROPS,
-            new CelebornSaslServer.DigestCallbackHandler(secretRegistry));
+            new CelebornSaslServer.DigestCallbackHandler(
+                mock(TransportClient.class), secretRegistry));
 
     assertFalse(client.isComplete());
     assertFalse(server.isComplete());
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 114f50f2b..e9b1a6607 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -147,7 +147,7 @@ class FetchHandler(
         val endIndices = openStreamList.getEndIndexList
         val readLocalFlags = openStreamList.getReadLocalShuffleList
         val pbOpenStreamListResponse = PbOpenStreamListResponse.newBuilder()
-
+        checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1)
         0 until files.size() foreach { idx =>
           val pbStreamHandlerOpt = handleReduceOpenStreamInternal(
             client,
@@ -319,6 +319,7 @@ class FetchHandler(
       isLegacy: Boolean,
       readLocalShuffle: Boolean = false,
       callback: RpcResponseCallback): Unit = {
+    checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1)
     workerSource.recordAppActiveConnection(client, shuffleKey)
     workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
     try {
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 356b0d4cd..be5431fa2 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -119,6 +119,7 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
           client,
           pushData,
           pushData.requestId,
+          pushData.shuffleKey,
           () => {
             val partitionType =
               shufflePartitionType.getOrDefault(pushData.shuffleKey, 
PartitionType.REDUCE)
@@ -143,6 +144,7 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
           client,
           pushMergedData,
           pushMergedData.requestId,
+          pushMergedData.shuffleKey,
           () =>
             handlePushMergedData(
               pushMergedData,
@@ -748,8 +750,10 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       client: TransportClient,
       message: RequestMessage,
       requestId: Long,
+      shuffleKey: String,
       handler: () => Unit,
       callback: RpcResponseCallback): Unit = {
+    checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1)
     try {
       handler()
     } catch {
@@ -843,6 +847,7 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       client,
       rpcRequest,
       requestId,
+      shuffleKey,
       () =>
         handleMapPartitionRpcRequestCore(
           requestId,
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
index 3bc4aabae..ef4f172bf 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
@@ -323,7 +323,7 @@ public class FetchHandlerSuiteJ {
     return fetchHandler;
   }
 
-  private final String shuffleKey = "dummyShuffleKey";
+  private final String shuffleKey = "dummyShuffleKey-123";
   private final String fileName = "dummyFileName";
   private final long dummyRequestId = 0;
 
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
index e6b688b43..1271d2c45 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
@@ -206,7 +206,7 @@ public class ReducePartitionDataWriterSuiteJ {
         new TransportMessage(
             MessageType.OPEN_STREAM,
             PbOpenStream.newBuilder()
-                .setShuffleKey("shuffleKey")
+                .setShuffleKey("shuffleKey-123")
                 .setFileName("location")
                 .setStartIndex(0)
                 .setEndIndex(Integer.MAX_VALUE)

Reply via email to