This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new b09febdd8 [CELEBORN-1176] Server side support for Sasl Auth
b09febdd8 is described below
commit b09febdd8c5f97e370c4cc46c776f1a9bc61e3c5
Author: Chandni Singh <[email protected]>
AuthorDate: Mon Dec 18 11:27:28 2023 +0800
[CELEBORN-1176] Server side support for Sasl Auth
### What changes were proposed in this pull request?
This adds server-side SASL authentication support to the transport
layer. Most of this code is taken from Apache Spark.
### Why are the changes needed?
The changes are needed for adding authentication to Celeborn. See
[CELEBORN-1011](https://issues.apache.org/jira/browse/CELEBORN-1011).
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added UTs.
Closes #2164 from otterc/CELEBORN-1176.
Authored-by: Chandni Singh <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
---
.../celeborn/common/network/TransportContext.java | 4 +
.../common/network/sasl/CelebornSaslServer.java | 159 ++++++++++++++++
.../common/network/sasl/SaslRpcHandler.java | 131 +++++++++++++
.../common/network/sasl/SaslServerBootstrap.java | 49 +++++
.../celeborn/common/network/sasl/SaslUtils.java | 13 ++
.../common/network/sasl/SecretRegistry.java | 27 +++
.../common/network/sasl/SecretRegistryImpl.java | 53 ++++++
.../network/server/AbstractAuthRpcHandler.java | 89 +++++++++
.../common/network/sasl/CelebornSaslSuiteJ.java | 207 +++++++++++++++++++++
9 files changed, 732 insertions(+)
diff --git a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
index c8796ea32..86ec779f2 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
@@ -115,6 +115,10 @@ public class TransportContext {
return new TransportServer(this, host, port, source, msgHandler,
bootstraps);
}
+  /**
+   * Creates a server with the given bootstraps by passing a {@code null} host and port {@code 0}
+   * to the three-argument overload. Port 0 presumably lets the OS choose a free port, so callers
+   * should read the bound port back via {@link TransportServer#getPort()}.
+   */
+  public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
+    return createServer(null, /* port= */ 0, bootstraps);
+  }
+
/** Creates a server for {@code port} with an empty bootstrap list and a {@code null} host. */
public TransportServer createServer(int port) {
return createServer(null, port, Collections.emptyList());
}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java b/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
new file mode 100644
index 000000000..7248a3389
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import static org.apache.celeborn.common.network.sasl.SaslUtils.*;
+
+import java.util.Map;
+
+import javax.annotation.Nullable;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import com.google.common.base.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Celeborn which simply keeps track of the state of a single SASL session, from
+ * the initial state to the "authenticated" state. (It is not a server in the sense of accepting
+ * connections on some socket.)
+ *
+ * <p>Every public method is {@code synchronized} on this instance. After {@link #dispose()},
+ * {@link #isComplete()} returns {@code false} and {@link #response(byte[])} returns an empty array.
+ */
+public class CelebornSaslServer {
+  private static final Logger logger = LoggerFactory.getLogger(CelebornSaslServer.class);
+
+  /** Underlying JDK SASL server; {@code null} once {@link #dispose()} has been called. */
+  private SaslServer saslServer;
+
+  /**
+   * Creates a SASL server for a single authentication session.
+   *
+   * @param saslMechanism name of the SASL mechanism to use; must not be {@code null}
+   * @param saslProps optional mechanism properties passed to {@link Sasl#createSaslServer}
+   * @param callbackHandler optional handler the mechanism uses to obtain credentials
+   * @throws IllegalArgumentException if the server cannot be created, including when no installed
+   *     provider supports {@code saslMechanism}
+   */
+  public CelebornSaslServer(
+      String saslMechanism,
+      @Nullable Map<String, String> saslProps,
+      @Nullable CallbackHandler callbackHandler) {
+    Preconditions.checkNotNull(saslMechanism);
+    try {
+      this.saslServer =
+          Sasl.createSaslServer(saslMechanism, null, DEFAULT_REALM, saslProps, callbackHandler);
+    } catch (SaslException e) {
+      throw new IllegalArgumentException(e);
+    }
+    // Sasl.createSaslServer returns null (instead of throwing) when no factory supports the
+    // mechanism. Fail fast rather than keep a server that answers every token with an empty
+    // response and can never complete.
+    if (this.saslServer == null) {
+      throw new IllegalArgumentException(
+          "No SASL server available for mechanism " + saslMechanism);
+    }
+  }
+
+  /** Determines whether the authentication exchange has completed successfully. */
+  public synchronized boolean isComplete() {
+    return saslServer != null && saslServer.isComplete();
+  }
+
+  /**
+   * Returns the value of a negotiated property, such as {@link Sasl#QOP}.
+   *
+   * @throws IllegalStateException if this server has been disposed (the JDK server additionally
+   *     throws it when authentication has not completed yet)
+   */
+  public synchronized Object getNegotiatedProperty(String name) {
+    Preconditions.checkState(saslServer != null, "SASL server has already been disposed");
+    return saslServer.getNegotiatedProperty(name);
+  }
+
+  /**
+   * Evaluates a SASL token received from the client.
+   *
+   * @param token the client's SASL response token
+   * @return the challenge to send back to the client; an empty array when the mechanism has
+   *     nothing more to send or when this server has been disposed
+   * @throws RuntimeException wrapping the {@link SaslException} raised for an invalid token, e.g.
+   *     when the client's credentials do not match
+   */
+  public synchronized byte[] response(byte[] token) {
+    if (saslServer == null) {
+      return EMPTY_BYTE_ARRAY;
+    }
+    try {
+      byte[] challenge = saslServer.evaluateResponse(token);
+      // SaslServer.evaluateResponse may return null once authentication has succeeded with no
+      // further data; normalize so callers can always wrap the result in a buffer.
+      return challenge != null ? challenge : EMPTY_BYTE_ARRAY;
+    } catch (SaslException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information the SaslServer might be
+   * using. Safe to call more than once; disposal failures are logged and otherwise ignored.
+   */
+  public synchronized void dispose() {
+    if (saslServer != null) {
+      try {
+        saslServer.dispose();
+      } catch (SaslException e) {
+        // Best effort: the session is being torn down anyway, so only record the failure.
+        logger.debug("Error while disposing SASL server", e);
+      } finally {
+        saslServer = null;
+      }
+    }
+  }
+
+  /**
+   * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
+   * The client's secret is looked up in a {@link SecretRegistry} under the Base64-decoded user
+   * name the client supplied.
+   */
+  static class DigestCallbackHandler implements CallbackHandler {
+    private final SecretRegistry secretRegistry;
+
+    /**
+     * Decoded user name from the latest NameCallback. The use of 'volatile' is not necessary here
+     * because the 'handle' invocation includes both the NameCallback and PasswordCallback (with
+     * the name preceding the password), all within the same thread.
+     */
+    private String userName = null;
+
+    DigestCallbackHandler(SecretRegistry secretRegistry) {
+      this.secretRegistry = Preconditions.checkNotNull(secretRegistry);
+    }
+
+    /**
+     * Answers the callbacks issued by the SASL mechanism.
+     *
+     * @throws SaslException if the client supplied no user name
+     * @throws UnsupportedCallbackException for callbacks other than name, password, realm and
+     *     authorize
+     * @throws RuntimeException if no secret is registered for the user
+     */
+    @Override
+    public void handle(Callback[] callbacks) throws UnsupportedCallbackException, SaslException {
+      for (Callback callback : callbacks) {
+        if (callback instanceof NameCallback) {
+          logger.trace("SASL server callback: setting username");
+          NameCallback nc = (NameCallback) callback;
+          String encodedName = nc.getName() != null ? nc.getName() : nc.getDefaultName();
+          if (encodedName == null) {
+            throw new SaslException("No username provided by client");
+          }
+          userName = decodeIdentifier(encodedName);
+        } else if (callback instanceof PasswordCallback) {
+          logger.trace("SASL server callback: setting password");
+          PasswordCallback pc = (PasswordCallback) callback;
+          String secret = secretRegistry.getSecretKey(userName);
+          if (secret == null) {
+            // TODO: CELEBORN-1179 Add support for fetching the secret from the Celeborn master.
+            throw new RuntimeException("Registration information not found for " + userName);
+          }
+          pc.setPassword(encodePassword(secret));
+        } else if (callback instanceof RealmCallback) {
+          logger.trace("SASL server callback: setting realm");
+          RealmCallback rc = (RealmCallback) callback;
+          rc.setText(rc.getDefaultText());
+        } else if (callback instanceof AuthorizeCallback) {
+          AuthorizeCallback ac = (AuthorizeCallback) callback;
+          String authId = ac.getAuthenticationID();
+          String authzId = ac.getAuthorizationID();
+          // A client may only act as itself: the authorization ID must equal the authenticated ID.
+          ac.setAuthorized(authId.equals(authzId));
+          if (ac.isAuthorized()) {
+            ac.setAuthorizedID(authzId);
+          }
+          logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized());
+        } else {
+          throw new UnsupportedCallbackException(callback, "Unrecognized callback: " + callback);
+        }
+      }
+    }
+  }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
new file mode 100644
index 000000000..5bc72f0f1
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import static org.apache.celeborn.common.network.sasl.SaslUtils.*;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.network.server.AbstractAuthRpcHandler;
+import org.apache.celeborn.common.network.server.BaseMessageHandler;
+import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.protocol.PbSaslRequest;
+
+/**
+ * Server-side RPC handler that runs a SASL handshake before any request reaches the wrapped
+ * application handler. Only connections that finish the handshake are delegated to, and a
+ * connection can authenticate at most once.
+ *
+ * <p>The handshake spans several challenge/response round trips. Each round trip is a separate RPC
+ * whose body is parsed as a {@code PbSaslRequest}.
+ */
+public class SaslRpcHandler extends AbstractAuthRpcHandler {
+  private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+  /** Transport configuration (stored, but not read by this class). */
+  private final TransportConf conf;
+
+  /** The client channel (stored, but not read by this class). */
+  private final Channel channel;
+
+  /** Supplies the per-app secret keys shared by server and client. */
+  private final SecretRegistry secretRegistry;
+
+  /** In-flight SASL state: created on the first token, disposed and cleared afterwards. */
+  private CelebornSaslServer saslServer;
+
+  public SaslRpcHandler(
+      TransportConf conf,
+      Channel channel,
+      BaseMessageHandler delegate,
+      SecretRegistry secretRegistry) {
+    super(delegate);
+    this.conf = conf;
+    this.channel = channel;
+    this.secretRegistry = secretRegistry;
+    this.saslServer = null;
+  }
+
+  /** Registration status is owned by the wrapped handler. */
+  @Override
+  public boolean checkRegistered() {
+    return delegate.checkRegistered();
+  }
+
+  /**
+   * Processes one step of the SASL handshake.
+   *
+   * @return {@code true} once the handshake has completed, after which the SASL state is released
+   */
+  @Override
+  public boolean doAuthChallenge(
+      TransportClient client, RequestMessage message, RpcResponseCallback callback) {
+    boolean awaitingToken = saslServer == null || !saslServer.isComplete();
+    if (awaitingToken && !evaluateToken(message, callback)) {
+      return false;
+    }
+    if (!saslServer.isComplete()) {
+      return false;
+    }
+    logger.debug("SASL authentication successful for channel {}", client);
+    // Handshake finished: the SASL state is no longer needed.
+    cleanup();
+    return true;
+  }
+
+  /**
+   * Parses one SASL token from {@code message}, lazily creates the DIGEST-MD5 server and replies to
+   * the client with the server's response.
+   *
+   * @return {@code false} if the message could not be parsed (the callback has been failed)
+   */
+  private boolean evaluateToken(RequestMessage message, RpcResponseCallback callback) {
+    RpcRequest rpcRequest = (RpcRequest) message;
+    PbSaslRequest saslMessage;
+    try {
+      TransportMessage pbMsg = TransportMessage.fromByteBuffer(message.body().nioByteBuffer());
+      saslMessage = pbMsg.getParsedPayload();
+    } catch (IOException e) {
+      logger.error("Error while parsing Sasl Message with RPC id {}", rpcRequest.requestId, e);
+      callback.onFailure(e);
+      return false;
+    }
+    if (saslServer == null) {
+      CelebornSaslServer.DigestCallbackHandler handler =
+          new CelebornSaslServer.DigestCallbackHandler(secretRegistry);
+      saslServer = new CelebornSaslServer(DIGEST_MD5, DEFAULT_SASL_SERVER_PROPS, handler);
+    }
+    byte[] reply = saslServer.response(saslMessage.getPayload().toByteArray());
+    callback.onSuccess(ByteBuffer.wrap(reply));
+    return true;
+  }
+
+  /** Forwards the event to the delegate, then releases any unfinished SASL state. */
+  @Override
+  public void channelInactive(TransportClient client) {
+    super.channelInactive(client);
+    cleanup();
+  }
+
+  /** Disposes the SASL server, if any; disposal errors are logged and never propagated. */
+  private void cleanup() {
+    if (saslServer == null) {
+      return;
+    }
+    try {
+      saslServer.dispose();
+    } catch (RuntimeException e) {
+      logger.error("Error while disposing SASL server", e);
+    } finally {
+      saslServer = null;
+    }
+  }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslServerBootstrap.java b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslServerBootstrap.java
new file mode 100644
index 000000000..44455f10e
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslServerBootstrap.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import io.netty.channel.Channel;
+
+import org.apache.celeborn.common.network.server.BaseMessageHandler;
+import org.apache.celeborn.common.network.server.TransportServerBootstrap;
+import org.apache.celeborn.common.network.util.TransportConf;
+
+/**
+ * Server-side bootstrap executed on a TransportServer's client channel once a client connects. It
+ * wraps the channel's message handler in a {@link SaslRpcHandler}, so the SASL handshake must
+ * complete before requests reach the application handler.
+ */
+public class SaslServerBootstrap implements TransportServerBootstrap {
+
+  private final TransportConf conf;
+  private final SecretRegistry secretRegistry;
+
+  public SaslServerBootstrap(TransportConf conf, SecretRegistry secretRegistry) {
+    this.conf = conf;
+    this.secretRegistry = secretRegistry;
+  }
+
+  /**
+   * Returns a new {@link SaslRpcHandler} for {@code channel} that performs the initial SASL
+   * negotiation and then delegates to {@code rpcHandler}.
+   */
+  @Override
+  public BaseMessageHandler doBootstrap(Channel channel, BaseMessageHandler rpcHandler) {
+    SaslRpcHandler saslHandler = new SaslRpcHandler(conf, channel, rpcHandler, secretRegistry);
+    return saslHandler;
+  }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslUtils.java b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslUtils.java
index 5d57057e4..75a20ee37 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslUtils.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslUtils.java
@@ -40,12 +40,25 @@ public class SaslUtils {
/* Default SASL client properties: only the quality of protection (QOP_AUTH) is set. */
static final Map<String, String> DEFAULT_SASL_CLIENT_PROPS =
ImmutableMap.<String, String>builder().put(Sasl.QOP, QOP_AUTH).build();
+  /**
+   * Default SASL server properties: require that the server authenticate to the client
+   * ({@link Sasl#SERVER_AUTH}) and use the {@code QOP_AUTH} quality of protection.
+   */
+  static final Map<String, String> DEFAULT_SASL_SERVER_PROPS =
+      ImmutableMap.of(Sasl.SERVER_AUTH, "true", Sasl.QOP, QOP_AUTH);
+
/* Encode a String identifier as a Base64 string of its UTF-8 bytes; rejects null. */
static String encodeIdentifier(String identifier) {
Preconditions.checkNotNull(identifier, "User cannot be null if SASL is
enabled");
return
Base64.getEncoder().encodeToString(identifier.getBytes(StandardCharsets.UTF_8));
}
+  /**
+   * Decodes a Base64 identifier back into a UTF-8 string; the inverse of {@code encodeIdentifier}.
+   *
+   * @throws NullPointerException if {@code identifier} is null
+   * @throws IllegalArgumentException if {@code identifier} is not valid Base64
+   */
+  static String decodeIdentifier(String identifier) {
+    Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
+    byte[] encodedBytes = identifier.getBytes(StandardCharsets.UTF_8);
+    byte[] decodedBytes = Base64.getDecoder().decode(encodedBytes);
+    return new String(decodedBytes, StandardCharsets.UTF_8);
+  }
+
/** Encode a password as a base64-encoded char[] array. */
static char[] encodePassword(String password) {
Preconditions.checkNotNull(password, "Password cannot be null if SASL is
enabled");
diff --git a/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistry.java b/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistry.java
new file mode 100644
index 000000000..995e7af80
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistry.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+/** Looks up the SASL secret key associated with an application. */
+public interface SecretRegistry {
+
+  /**
+   * Gets an appropriate SASL secret key for the given appId.
+   *
+   * @param appId the application whose secret is requested
+   * @return the secret; may be {@code null} when no secret is known, so callers must check
+   */
+  String getSecretKey(String appId);
+
+  /** Returns whether a secret has been registered for {@code appId}. */
+  boolean isRegistered(String appId);
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistryImpl.java b/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistryImpl.java
new file mode 100644
index 000000000..402fc3a27
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistryImpl.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * A simple implementation of {@link SecretRegistry} that stores secrets in memory in a {@link
+ * ConcurrentHashMap}. It is designed as a singleton obtained via {@link #getInstance()}.
+ */
+public class SecretRegistryImpl implements SecretRegistry {
+
+  private static final SecretRegistryImpl INSTANCE = new SecretRegistryImpl();
+
+  /** Returns the shared registry instance. */
+  public static SecretRegistryImpl getInstance() {
+    return INSTANCE;
+  }
+
+  /** Secrets keyed by appId. */
+  private final ConcurrentHashMap<String, String> secrets = new ConcurrentHashMap<>();
+
+  /**
+   * Registers, or replaces, the secret for {@code appId}.
+   *
+   * @throws NullPointerException if {@code appId} or {@code secret} is null
+   */
+  public void register(String appId, String secret) {
+    secrets.put(appId, secret);
+  }
+
+  /** Removes the secret for {@code appId}; does nothing if none is registered. */
+  public void unregister(String appId) {
+    secrets.remove(appId);
+  }
+
+  /** Returns whether a secret is currently registered for {@code appId}. */
+  @Override
+  public boolean isRegistered(String appId) {
+    return secrets.containsKey(appId);
+  }
+
+  /**
+   * Gets an appropriate SASL secret key for the given appId.
+   *
+   * @return the registered secret, or {@code null} if {@code appId} is not registered
+   */
+  @Override
+  public String getSecretKey(String appId) {
+    return secrets.get(appId);
+  }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/server/AbstractAuthRpcHandler.java b/common/src/main/java/org/apache/celeborn/common/network/server/AbstractAuthRpcHandler.java
new file mode 100644
index 000000000..36eb89691
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/server/AbstractAuthRpcHandler.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+
+/**
+ * RPC Handler which performs authentication, and when it's successful, delegates further calls to
+ * another RPC handler. The authentication handshake itself should be implemented by subclasses.
+ *
+ * <p>Until authentication succeeds, requests that carry a response callback are passed to {@link
+ * #doAuthChallenge}, and requests without one are rejected with a {@link SecurityException}.
+ * Channel lifecycle and exception events are always forwarded to the delegate.
+ */
+public abstract class AbstractAuthRpcHandler extends BaseMessageHandler {
+  private static final Logger LOG = LoggerFactory.getLogger(AbstractAuthRpcHandler.class);
+
+  /** RpcHandler we will delegate to for authenticated connections. */
+  protected final BaseMessageHandler delegate;
+
+  /** Becomes true once {@link #doAuthChallenge} succeeds; this class never resets it. */
+  private boolean isAuthenticated;
+
+  protected AbstractAuthRpcHandler(BaseMessageHandler delegate) {
+    this.delegate = delegate;
+  }
+
+  /**
+   * Responds to an authentication challenge. While the connection is unauthenticated, the request
+   * callback is handed only to this method, so implementations must reply through it.
+   *
+   * @return Whether the client is authenticated.
+   */
+  protected abstract boolean doAuthChallenge(
+      TransportClient client, RequestMessage message, RpcResponseCallback callback);
+
+  @Override
+  public final void receive(
+      TransportClient client, RequestMessage message, RpcResponseCallback callback) {
+    if (isAuthenticated) {
+      LOG.trace("Already authenticated. Delegating {}", client.getClientId());
+      delegate.receive(client, message, callback);
+    } else {
+      isAuthenticated = doAuthChallenge(client, message, callback);
+    }
+  }
+
+  /**
+   * Delegates a request that has no response callback.
+   *
+   * @throws SecurityException if the connection has not authenticated yet
+   */
+  @Override
+  public final void receive(TransportClient client, RequestMessage message) {
+    if (isAuthenticated) {
+      delegate.receive(client, message);
+    } else {
+      throw new SecurityException("Unauthenticated call to receive().");
+    }
+  }
+
+  @Override
+  public void channelActive(TransportClient client) {
+    delegate.channelActive(client);
+  }
+
+  @Override
+  public void channelInactive(TransportClient client) {
+    delegate.channelInactive(client);
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause, TransportClient client) {
+    delegate.exceptionCaught(cause, client);
+  }
+
+  /** Returns whether this connection has completed authentication. */
+  public boolean isAuthenticated() {
+    return isAuthenticated;
+  }
+}
diff --git a/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
new file mode 100644
index 000000000..582988e1d
--- /dev/null
+++ b/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import static org.apache.celeborn.common.network.sasl.SaslUtils.*;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TransportContext;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientBootstrap;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.server.BaseMessageHandler;
+import org.apache.celeborn.common.network.server.TransportServer;
+import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.util.JavaUtils;
+
+/**
+ * Jointly tests {@link CelebornSaslClient} and {@link CelebornSaslServer}, as both are black boxes.
+ */
+public class CelebornSaslSuiteJ {
+  private static final String TEST_USER = "appId";
+  private static final String TEST_SECRET = "secret";
+
+  @BeforeClass
+  public static void setup() {
+    // The server-side callback handler looks secrets up in this shared singleton registry.
+    SecretRegistryImpl.getInstance().register(TEST_USER, TEST_SECRET);
+  }
+
+  @Test
+  public void testDigestMatching() {
+    CelebornSaslClient client =
+        new CelebornSaslClient(
+            DIGEST_MD5,
+            DEFAULT_SASL_CLIENT_PROPS,
+            new CelebornSaslClient.ClientCallbackHandler(TEST_USER, TEST_SECRET));
+    CelebornSaslServer server =
+        new CelebornSaslServer(
+            DIGEST_MD5,
+            DEFAULT_SASL_SERVER_PROPS,
+            new CelebornSaslServer.DigestCallbackHandler(SecretRegistryImpl.getInstance()));
+
+    assertFalse(client.isComplete());
+    assertFalse(server.isComplete());
+
+    byte[] clientMessage = client.firstToken();
+
+    while (!client.isComplete()) {
+      clientMessage = client.response(server.response(clientMessage));
+    }
+    assertTrue(server.isComplete());
+
+    // Disposal should invalidate
+    server.dispose();
+    assertFalse(server.isComplete());
+    client.dispose();
+    assertFalse(client.isComplete());
+  }
+
+  @Test
+  public void testDigestNonMatching() {
+    CelebornSaslClient client =
+        new CelebornSaslClient(
+            DIGEST_MD5,
+            DEFAULT_SASL_CLIENT_PROPS,
+            new CelebornSaslClient.ClientCallbackHandler(TEST_USER, "invalid" + TEST_SECRET));
+    CelebornSaslServer server =
+        new CelebornSaslServer(
+            DIGEST_MD5,
+            DEFAULT_SASL_SERVER_PROPS,
+            new CelebornSaslServer.DigestCallbackHandler(SecretRegistryImpl.getInstance()));
+
+    assertFalse(client.isComplete());
+    assertFalse(server.isComplete());
+
+    byte[] clientMessage = client.firstToken();
+
+    try {
+      while (!client.isComplete()) {
+        clientMessage = client.response(server.response(clientMessage));
+      }
+      fail("Should not have completed");
+    } catch (Exception e) {
+      assertTrue(e.getMessage().contains("Mismatched response"));
+      assertFalse(client.isComplete());
+      assertFalse(server.isComplete());
+    } finally {
+      // Release SASL state even though the handshake is expected to fail.
+      server.dispose();
+      client.dispose();
+    }
+  }
+
+  @Test
+  public void testSaslAuth() throws Throwable {
+    BaseMessageHandler rpcHandler = mock(BaseMessageHandler.class);
+    doAnswer(
+            invocation -> {
+              RequestMessage message = (RequestMessage) invocation.getArguments()[1];
+              RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
+              assertEquals("Ping", JavaUtils.bytesToString(message.body().nioByteBuffer()));
+              cb.onSuccess(JavaUtils.stringToBytes("Pong"));
+              return null;
+            })
+        .when(rpcHandler)
+        .receive(
+            any(TransportClient.class), any(RequestMessage.class), any(RpcResponseCallback.class));
+
+    doReturn(true).when(rpcHandler).checkRegistered();
+
+    try (SaslTestCtx ctx = new SaslTestCtx(rpcHandler)) {
+      ByteBuffer response =
+          ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10));
+      assertEquals("Pong", JavaUtils.bytesToString(response));
+    } finally {
+      // There should be 2 terminated events; one for the client, one for the server.
+      Throwable error = null;
+      long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+      while (deadline > System.nanoTime()) {
+        try {
+          verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class));
+          error = null;
+          break;
+        } catch (Throwable t) {
+          error = t;
+          TimeUnit.MILLISECONDS.sleep(10);
+        }
+      }
+      if (error != null) {
+        throw error;
+      }
+    }
+  }
+
+  @Test
+  public void testRpcHandlerDelegate() {
+    // Tests all delegates except for receive(), which is more complicated and already handled
+    // by all other tests.
+    BaseMessageHandler handler = mock(BaseMessageHandler.class);
+    BaseMessageHandler saslHandler = new SaslRpcHandler(null, null, handler, null);
+
+    saslHandler.channelInactive(null);
+    verify(handler).channelInactive(isNull());
+
+    saslHandler.exceptionCaught(null, null);
+    verify(handler).exceptionCaught(isNull(), isNull());
+  }
+
+  /** Server and SASL-bootstrapped client sharing one TransportContext; closes both on close(). */
+  private static class SaslTestCtx implements AutoCloseable {
+
+    final TransportClient client;
+    final TransportServer server;
+    final TransportContext ctx;
+
+    SaslTestCtx(BaseMessageHandler rpcHandler) throws Exception {
+      TransportConf conf = new TransportConf("shuffle", new CelebornConf());
+
+      this.ctx = new TransportContext(conf, rpcHandler);
+      this.server =
+          ctx.createServer(
+              Collections.singletonList(
+                  new SaslServerBootstrap(conf, SecretRegistryImpl.getInstance())));
+      List<TransportClientBootstrap> clientBootstraps = new ArrayList<>();
+      clientBootstraps.add(
+          new SaslClientBootstrap(conf, "appId", new SaslCredentials(TEST_USER, TEST_SECRET)));
+      try {
+        this.client =
+            ctx.createClientFactory(clientBootstraps)
+                .createClient(JavaUtils.getLocalHost(), server.getPort());
+      } catch (Exception e) {
+        close();
+        throw e;
+      }
+    }
+
+    public void close() {
+      if (client != null) {
+        client.close();
+      }
+      if (server != null) {
+        server.close();
+      }
+    }
+  }
+}