This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 6153ba4c6 [CELEBORN-1360] Ensure that a client cannot push or fetch
data belonging to a different application
6153ba4c6 is described below
commit 6153ba4c627de4bdeedcb4f1e4ae45f18ad2965b
Author: Chandni Singh <[email protected]>
AuthorDate: Wed Apr 3 09:56:40 2024 +0800
[CELEBORN-1360] Ensure that a client cannot push or fetch data belonging to
a different application
### What changes were proposed in this pull request?
This ensures that an authenticated client cannot push or fetch
data that belongs to another application.
### Why are the changes needed?
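These changes are required for authentication support: without them, an authenticated client could push or fetch data belonging to a different application.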
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Closes #2431 from otterc/CELEBORN-1360.
Authored-by: Chandni Singh <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../apache/celeborn/common/network/client/TransportClient.java | 3 ++-
.../apache/celeborn/common/network/sasl/CelebornSaslServer.java | 7 ++++++-
.../org/apache/celeborn/common/network/sasl/SaslRpcHandler.java | 2 +-
.../network/sasl/registration/RegistrationClientBootstrap.java | 1 +
.../common/network/sasl/registration/RegistrationRpcHandler.java | 1 +
.../apache/celeborn/common/network/server/BaseMessageHandler.java | 8 ++++++++
.../apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java | 8 ++++++--
.../org/apache/celeborn/service/deploy/worker/FetchHandler.scala | 3 ++-
.../apache/celeborn/service/deploy/worker/PushDataHandler.scala | 5 +++++
.../apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java | 2 +-
.../deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java | 2 +-
11 files changed, 34 insertions(+), 8 deletions(-)
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index f4b62d872..fd64b6bd0 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -354,7 +354,8 @@ public class TransportClient implements Closeable {
* <p>Trying to set a different client ID after it's been set will result in
an exception.
*/
public void setClientId(String id) {
- Preconditions.checkState(clientId == null, "Client ID has already been
set.");
+ Preconditions.checkState(
+ clientId == null || clientId.equals(id), "Client ID has already been
set.");
this.clientId = id;
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
b/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
index efb9af5a5..05c05744c 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
@@ -37,6 +37,8 @@ import com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.celeborn.common.network.client.TransportClient;
+
/**
* A SASL Server for Celeborn which simply keeps track of the state of a
single SASL session, from
* the initial state to the "authenticated" state. (It is not a server in the
sense of accepting
@@ -106,6 +108,7 @@ public class CelebornSaslServer {
*/
static class DigestCallbackHandler implements CallbackHandler {
private final SecretRegistry secretRegistry;
+ private final TransportClient client;
/**
* The use of 'volatile' is not necessary here because the 'handle'
invocation includes both the
@@ -114,7 +117,8 @@ public class CelebornSaslServer {
*/
private String userName = null;
- DigestCallbackHandler(SecretRegistry secretRegistry) {
+ DigestCallbackHandler(TransportClient client, SecretRegistry
secretRegistry) {
+ this.client = Preconditions.checkNotNull(client);
this.secretRegistry = Preconditions.checkNotNull(secretRegistry);
}
@@ -129,6 +133,7 @@ public class CelebornSaslServer {
throw new SaslException("No username provided by client");
}
userName = decodeIdentifier(encodedName);
+ client.setClientId(userName);
} else if (callback instanceof PasswordCallback) {
logger.trace("SASL server callback: setting password");
PasswordCallback pc = (PasswordCallback) callback;
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
index 127b3e772..0033372ea 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
@@ -94,7 +94,7 @@ public class SaslRpcHandler extends AbstractAuthRpcHandler {
new CelebornSaslServer(
DIGEST_MD5,
DEFAULT_SASL_SERVER_PROPS,
- new CelebornSaslServer.DigestCallbackHandler(secretRegistry));
+ new CelebornSaslServer.DigestCallbackHandler(client,
secretRegistry));
}
byte[] response =
saslServer.response(saslMessage.getPayload().toByteArray());
callback.onSuccess(ByteBuffer.wrap(response));
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationClientBootstrap.java
b/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationClientBootstrap.java
index 3604e927c..a6aa2ba88 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationClientBootstrap.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationClientBootstrap.java
@@ -112,6 +112,7 @@ public class RegistrationClientBootstrap implements
TransportClientBootstrap {
register(client);
LOG.info("Registration for {}", appId);
registrationInfo.setRegistrationState(RegistrationInfo.RegistrationState.REGISTERED);
+ client.setClientId(appId);
} catch (IOException | CelebornException e) {
throw new RuntimeException(e);
} finally {
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationRpcHandler.java
b/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationRpcHandler.java
index 10803f95d..c0b25e6f9 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationRpcHandler.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/sasl/registration/RegistrationRpcHandler.java
@@ -181,6 +181,7 @@ public class RegistrationRpcHandler extends
BaseMessageHandler {
LOG.trace("Application registration started {}",
registerApplicationRequest.getId());
processRegisterApplicationRequest(registerApplicationRequest,
callback);
registrationState = RegistrationState.REGISTERED;
+ client.setClientId(registerApplicationRequest.getId());
LOG.info(
"Application registered: appId {} rpcId {}",
registerApplicationRequest.getId(),
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
index d975dc482..974166745 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
@@ -46,4 +46,12 @@ public class BaseMessageHandler {
public void channelInactive(TransportClient client) {}
public void exceptionCaught(Throwable cause, TransportClient client) {}
+
+ protected void checkAuth(TransportClient client, String appId) {
+ if (client.getClientId() != null && !client.getClientId().equals(appId)) {
+ throw new SecurityException(
+ String.format(
+ "Client for %s not authorized for application %s.",
client.getClientId(), appId));
+ }
+ }
}
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
b/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
index b97da2e86..1a24a8ddc 100644
---
a/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
+++
b/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
@@ -24,6 +24,7 @@ import static org.mockito.Mockito.*;
import org.junit.Test;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.util.TransportConf;
@@ -39,11 +40,12 @@ public class CelebornSaslSuiteJ extends SaslTestBase {
DIGEST_MD5,
DEFAULT_SASL_CLIENT_PROPS,
new CelebornSaslClient.ClientCallbackHandler(TEST_USER,
TEST_SECRET));
+ TransportClient transportClient = mock(TransportClient.class);
CelebornSaslServer server =
new CelebornSaslServer(
DIGEST_MD5,
DEFAULT_SASL_SERVER_PROPS,
- new CelebornSaslServer.DigestCallbackHandler(secretRegistry));
+ new CelebornSaslServer.DigestCallbackHandler(transportClient,
secretRegistry));
assertFalse(client.isComplete());
assertFalse(server.isComplete());
@@ -54,6 +56,7 @@ public class CelebornSaslSuiteJ extends SaslTestBase {
clientMessage = client.response(server.response(clientMessage));
}
assertTrue(server.isComplete());
+ verify(transportClient, times(1)).setClientId(TEST_USER);
// Disposal should invalidate
server.dispose();
@@ -73,7 +76,8 @@ public class CelebornSaslSuiteJ extends SaslTestBase {
new CelebornSaslServer(
DIGEST_MD5,
DEFAULT_SASL_SERVER_PROPS,
- new CelebornSaslServer.DigestCallbackHandler(secretRegistry));
+ new CelebornSaslServer.DigestCallbackHandler(
+ mock(TransportClient.class), secretRegistry));
assertFalse(client.isComplete());
assertFalse(server.isComplete());
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 114f50f2b..e9b1a6607 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -147,7 +147,7 @@ class FetchHandler(
val endIndices = openStreamList.getEndIndexList
val readLocalFlags = openStreamList.getReadLocalShuffleList
val pbOpenStreamListResponse = PbOpenStreamListResponse.newBuilder()
-
+ checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1)
0 until files.size() foreach { idx =>
val pbStreamHandlerOpt = handleReduceOpenStreamInternal(
client,
@@ -319,6 +319,7 @@ class FetchHandler(
isLegacy: Boolean,
readLocalShuffle: Boolean = false,
callback: RpcResponseCallback): Unit = {
+ checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1)
workerSource.recordAppActiveConnection(client, shuffleKey)
workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
try {
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 356b0d4cd..be5431fa2 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -119,6 +119,7 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
client,
pushData,
pushData.requestId,
+ pushData.shuffleKey,
() => {
val partitionType =
shufflePartitionType.getOrDefault(pushData.shuffleKey,
PartitionType.REDUCE)
@@ -143,6 +144,7 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
client,
pushMergedData,
pushMergedData.requestId,
+ pushMergedData.shuffleKey,
() =>
handlePushMergedData(
pushMergedData,
@@ -748,8 +750,10 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
client: TransportClient,
message: RequestMessage,
requestId: Long,
+ shuffleKey: String,
handler: () => Unit,
callback: RpcResponseCallback): Unit = {
+ checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1)
try {
handler()
} catch {
@@ -843,6 +847,7 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
client,
rpcRequest,
requestId,
+ shuffleKey,
() =>
handleMapPartitionRpcRequestCore(
requestId,
diff --git
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
index 3bc4aabae..ef4f172bf 100644
---
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
+++
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
@@ -323,7 +323,7 @@ public class FetchHandlerSuiteJ {
return fetchHandler;
}
- private final String shuffleKey = "dummyShuffleKey";
+ private final String shuffleKey = "dummyShuffleKey-123";
private final String fileName = "dummyFileName";
private final long dummyRequestId = 0;
diff --git
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
index e6b688b43..1271d2c45 100644
---
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
+++
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
@@ -206,7 +206,7 @@ public class ReducePartitionDataWriterSuiteJ {
new TransportMessage(
MessageType.OPEN_STREAM,
PbOpenStream.newBuilder()
- .setShuffleKey("shuffleKey")
+ .setShuffleKey("shuffleKey-123")
.setFileName("location")
.setStartIndex(0)
.setEndIndex(Integer.MAX_VALUE)