This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new b09febdd8 [CELEBORN-1176] Server side support for Sasl Auth
b09febdd8 is described below

commit b09febdd8c5f97e370c4cc46c776f1a9bc61e3c5
Author: Chandni Singh <[email protected]>
AuthorDate: Mon Dec 18 11:27:28 2023 +0800

    [CELEBORN-1176] Server side support for Sasl Auth
    
    ### What changes were proposed in this pull request?
    
    This adds server-side SASL authentication support to the transport layer. Most of this code is taken from Apache Spark.
    
    ### Why are the changes needed?
    
    The changes are needed for adding authentication to Celeborn. See 
[CELEBORN-1011](https://issues.apache.org/jira/browse/CELEBORN-1011).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added UTs.
    
    Closes #2164 from otterc/CELEBORN-1176.
    
    Authored-by: Chandni Singh <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../celeborn/common/network/TransportContext.java  |   4 +
 .../common/network/sasl/CelebornSaslServer.java    | 159 ++++++++++++++++
 .../common/network/sasl/SaslRpcHandler.java        | 131 +++++++++++++
 .../common/network/sasl/SaslServerBootstrap.java   |  49 +++++
 .../celeborn/common/network/sasl/SaslUtils.java    |  13 ++
 .../common/network/sasl/SecretRegistry.java        |  27 +++
 .../common/network/sasl/SecretRegistryImpl.java    |  53 ++++++
 .../network/server/AbstractAuthRpcHandler.java     |  89 +++++++++
 .../common/network/sasl/CelebornSaslSuiteJ.java    | 207 +++++++++++++++++++++
 9 files changed, 732 insertions(+)

diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java 
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
index c8796ea32..86ec779f2 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
@@ -115,6 +115,10 @@ public class TransportContext {
     return new TransportServer(this, host, port, source, msgHandler, 
bootstraps);
   }
 
+  public TransportServer createServer(List<TransportServerBootstrap> 
bootstraps) {
+    return createServer(null, 0, bootstraps);
+  }
+
   public TransportServer createServer(int port) {
     return createServer(null, port, Collections.emptyList());
   }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
new file mode 100644
index 000000000..7248a3389
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/CelebornSaslServer.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import static org.apache.celeborn.common.network.sasl.SaslUtils.*;
+
+import java.util.Map;
+
+import javax.annotation.Nullable;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import com.google.common.base.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Celeborn which simply keeps track of the state of a 
single SASL session, from
+ * the initial state to the "authenticated" state. (It is not a server in the 
sense of accepting
+ * connections on some socket.)
+ */
+public class CelebornSaslServer {
+  private static final Logger logger = 
LoggerFactory.getLogger(CelebornSaslServer.class);
+
+  private SaslServer saslServer;
+
+  public CelebornSaslServer(
+      String saslMechanism,
+      @Nullable Map<String, String> saslProps,
+      @Nullable CallbackHandler callbackHandler) {
+    Preconditions.checkNotNull(saslMechanism);
+    try {
+      this.saslServer =
+          Sasl.createSaslServer(saslMechanism, null, DEFAULT_REALM, saslProps, 
callbackHandler);
+    } catch (SaslException e) {
+      throw new IllegalArgumentException(e);
+    }
+  }
+
+  /** Determines whether the authentication exchange has completed 
successfully. */
+  public synchronized boolean isComplete() {
+    return saslServer != null && saslServer.isComplete();
+  }
+
+  /** Returns the value of a negotiated property. */
+  public synchronized Object getNegotiatedProperty(String name) {
+    return saslServer.getNegotiatedProperty(name);
+  }
+
+  /**
+   * Used to respond to server SASL tokens.
+   *
+   * @param token Server's SASL token
+   * @return response to send back to the server.
+   */
+  public synchronized byte[] response(byte[] token) {
+    try {
+      return saslServer != null ? saslServer.evaluateResponse(token) : 
EMPTY_BYTE_ARRAY;
+    } catch (SaslException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information the 
SaslServer might be
+   * using.
+   */
+  public synchronized void dispose() {
+    if (saslServer != null) {
+      try {
+        saslServer.dispose();
+      } catch (SaslException e) {
+        // ignore
+      } finally {
+        saslServer = null;
+      }
+    }
+  }
+
+  /**
+   * Implementation of javax.security.auth.callback.CallbackHandler for SASL 
DIGEST-MD5 mechanism.
+   */
+  static class DigestCallbackHandler implements CallbackHandler {
+    private final SecretRegistry secretKeyHolder;
+
+    /**
+     * The use of 'volatile' is not necessary here because the 'handle' 
invocation includes both the
+     * NameCallback and PasswordCallback (with the name preceding the 
password), all within the same
+     * thread.
+     */
+    private String userName = null;
+
+    DigestCallbackHandler(SecretRegistry secretRegistry) {
+      this.secretKeyHolder = Preconditions.checkNotNull(secretRegistry);
+    }
+
+    @Override
+    public void handle(Callback[] callbacks) throws 
UnsupportedCallbackException, SaslException {
+      for (Callback callback : callbacks) {
+        if (callback instanceof NameCallback) {
+          logger.trace("SASL server callback: setting username");
+          NameCallback nc = (NameCallback) callback;
+          String encodedName = nc.getName() != null ? nc.getName() : 
nc.getDefaultName();
+          if (encodedName == null) {
+            throw new SaslException("No username provided by client");
+          }
+          userName = decodeIdentifier(encodedName);
+        } else if (callback instanceof PasswordCallback) {
+          logger.trace("SASL server callback: setting password");
+          PasswordCallback pc = (PasswordCallback) callback;
+          String secret = secretKeyHolder.getSecretKey(userName);
+          if (secret == null) {
+            // TODO: CELEBORN-1179 Add support for fetching the secret from 
the Celeborn master.
+            throw new RuntimeException("Registration information not found for 
" + userName);
+          }
+          pc.setPassword(encodePassword(secret));
+        } else if (callback instanceof RealmCallback) {
+          logger.trace("SASL server callback: setting realm");
+          RealmCallback rc = (RealmCallback) callback;
+          rc.setText(rc.getDefaultText());
+        } else if (callback instanceof AuthorizeCallback) {
+          AuthorizeCallback ac = (AuthorizeCallback) callback;
+          String authId = ac.getAuthenticationID();
+          String authzId = ac.getAuthorizationID();
+          ac.setAuthorized(authId.equals(authzId));
+          if (ac.isAuthorized()) {
+            ac.setAuthorizedID(authzId);
+          }
+          logger.debug("SASL Authorization complete, authorized set to {}", 
ac.isAuthorized());
+        } else {
+          throw new UnsupportedCallbackException(callback, "Unrecognized 
callback: " + callback);
+        }
+      }
+    }
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
new file mode 100644
index 000000000..5bc72f0f1
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import static org.apache.celeborn.common.network.sasl.SaslUtils.*;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.network.server.AbstractAuthRpcHandler;
+import org.apache.celeborn.common.network.server.BaseMessageHandler;
+import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.protocol.PbSaslRequest;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child 
RPC handler. The
+ * delegate will only receive messages if the given connection has been 
successfully authenticated.
+ * A connection may be authenticated at most once.
+ *
+ * <p>Note that the authentication process consists of multiple 
challenge-response pairs, each of
+ * which are individual RPCs.
+ */
+public class SaslRpcHandler extends AbstractAuthRpcHandler {
+  private static final Logger logger = 
LoggerFactory.getLogger(SaslRpcHandler.class);
+
+  /** Transport configuration. */
+  private final TransportConf conf;
+
+  /** The client channel. */
+  private final Channel channel;
+
+  /** Class which provides secret keys which are shared by server and client 
on a per-app basis. */
+  private final SecretRegistry secretRegistry;
+
+  private CelebornSaslServer saslServer;
+
+  public SaslRpcHandler(
+      TransportConf conf,
+      Channel channel,
+      BaseMessageHandler delegate,
+      SecretRegistry secretRegistry) {
+    super(delegate);
+    this.conf = conf;
+    this.channel = channel;
+    this.secretRegistry = secretRegistry;
+    this.saslServer = null;
+  }
+
+  @Override
+  public boolean checkRegistered() {
+    return delegate.checkRegistered();
+  }
+
+  @Override
+  public boolean doAuthChallenge(
+      TransportClient client, RequestMessage message, RpcResponseCallback 
callback) {
+    if (saslServer == null || !saslServer.isComplete()) {
+      RpcRequest rpcRequest = (RpcRequest) message;
+      PbSaslRequest saslMessage;
+      try {
+        TransportMessage pbMsg = 
TransportMessage.fromByteBuffer(message.body().nioByteBuffer());
+        saslMessage = pbMsg.getParsedPayload();
+      } catch (IOException e) {
+        logger.error("Error while parsing Sasl Message with RPC id {}", 
rpcRequest.requestId, e);
+        callback.onFailure(e);
+        return false;
+      }
+      if (saslServer == null) {
+        saslServer =
+            new CelebornSaslServer(
+                DIGEST_MD5,
+                DEFAULT_SASL_SERVER_PROPS,
+                new CelebornSaslServer.DigestCallbackHandler(secretRegistry));
+      }
+      byte[] response = 
saslServer.response(saslMessage.getPayload().toByteArray());
+      callback.onSuccess(ByteBuffer.wrap(response));
+    }
+    if (saslServer.isComplete()) {
+      logger.debug("SASL authentication successful for channel {}", client);
+      complete();
+      return true;
+    }
+    return false;
+  }
+
+  @Override
+  public void channelInactive(TransportClient client) {
+    super.channelInactive(client);
+    cleanup();
+  }
+
+  private void complete() {
+    cleanup();
+  }
+
+  private void cleanup() {
+    if (null != saslServer) {
+      try {
+        saslServer.dispose();
+      } catch (RuntimeException e) {
+        logger.error("Error while disposing SASL server", e);
+      } finally {
+        saslServer = null;
+      }
+    }
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslServerBootstrap.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslServerBootstrap.java
new file mode 100644
index 000000000..44455f10e
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslServerBootstrap.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import io.netty.channel.Channel;
+
+import org.apache.celeborn.common.network.server.BaseMessageHandler;
+import org.apache.celeborn.common.network.server.TransportServerBootstrap;
+import org.apache.celeborn.common.network.util.TransportConf;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a 
client connects to the
+ * server. This allows customizing the client channel to allow for things such 
as SASL
+ * authentication.
+ */
+public class SaslServerBootstrap implements TransportServerBootstrap {
+
+  private final TransportConf conf;
+  private final SecretRegistry secretKeyHolder;
+
+  public SaslServerBootstrap(TransportConf conf, SecretRegistry 
secretRegistry) {
+    this.conf = conf;
+    this.secretKeyHolder = secretRegistry;
+  }
+
+  /**
+   * Wrap the given application handler in a SaslRpcHandler that will handle 
the initial SASL
+   * negotiation.
+   */
+  @Override
+  public BaseMessageHandler doBootstrap(Channel channel, BaseMessageHandler 
rpcHandler) {
+    return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslUtils.java 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslUtils.java
index 5d57057e4..75a20ee37 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslUtils.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SaslUtils.java
@@ -40,12 +40,25 @@ public class SaslUtils {
   static final Map<String, String> DEFAULT_SASL_CLIENT_PROPS =
       ImmutableMap.<String, String>builder().put(Sasl.QOP, QOP_AUTH).build();
 
+  /**
+   * Properties for the server side of the exchange: the server must also authenticate to the
+   * client ({@link Sasl#SERVER_AUTH}), using the same QOP ({@code QOP_AUTH}) as the client props.
+   */
+  static final Map<String, String> DEFAULT_SASL_SERVER_PROPS =
+      ImmutableMap.<String, String>builder()
+          .put(Sasl.SERVER_AUTH, "true")
+          .put(Sasl.QOP, QOP_AUTH)
+          .build();
+
   /** Encode a String identifier as the Base64 encoding of its UTF-8 bytes. */
   static String encodeIdentifier(String identifier) {
     Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
     return Base64.getEncoder().encodeToString(identifier.getBytes(StandardCharsets.UTF_8));
   }
 
+  /**
+   * Inverse of {@code encodeIdentifier}: Base64-decodes {@code identifier} and reads the bytes as
+   * UTF-8.
+   *
+   * @throws NullPointerException if {@code identifier} is null
+   * @throws IllegalArgumentException if {@code identifier} is not valid Base64
+   */
+  static String decodeIdentifier(String identifier) {
+    Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
+    return new String(
+        Base64.getDecoder().decode(identifier.getBytes(StandardCharsets.UTF_8)),
+        StandardCharsets.UTF_8);
+  }
+
   /** Encode a password as a base64-encoded char[] array. */
   static char[] encodePassword(String password) {
     Preconditions.checkNotNull(password, "Password cannot be null if SASL is 
enabled");
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistry.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistry.java
new file mode 100644
index 000000000..995e7af80
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistry.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+/**
+ * Source of the per-application secrets used for SASL authentication, keyed by application id
+ * (the identifier a client authenticates as).
+ */
+public interface SecretRegistry {
+
+  /** Returns whether a secret is currently registered for {@code appId}. */
+  boolean isRegistered(String appId);
+
+  /**
+   * Returns the SASL secret key registered for {@code appId}.
+   *
+   * <p>The server-side DIGEST-MD5 callback treats a {@code null} result as "not registered".
+   */
+  String getSecretKey(String appId);
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistryImpl.java
 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistryImpl.java
new file mode 100644
index 000000000..402fc3a27
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/sasl/SecretRegistryImpl.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * A simple implementation of {@link SecretRegistry} that stores secrets in memory. It is designed
+ * as a singleton, obtained via {@link #getInstance()}, and is backed by a {@link
+ * ConcurrentHashMap}.
+ *
+ * <p>A {@code null} application id is treated as "not registered" by lookups and ignored by
+ * {@link #unregister(String)}, instead of surfacing the map's NullPointerException.
+ */
+public class SecretRegistryImpl implements SecretRegistry {
+
+  private static final SecretRegistryImpl INSTANCE = new SecretRegistryImpl();
+
+  /** Returns the shared singleton instance. */
+  public static SecretRegistryImpl getInstance() {
+    return INSTANCE;
+  }
+
+  private final ConcurrentHashMap<String, String> secrets = new ConcurrentHashMap<>();
+
+  /**
+   * Registers (or replaces) the secret for {@code appId}.
+   *
+   * @throws NullPointerException if {@code appId} or {@code secret} is null
+   */
+  public void register(String appId, String secret) {
+    secrets.put(appId, secret);
+  }
+
+  /** Removes any secret registered for {@code appId}; a {@code null} id is a no-op. */
+  public void unregister(String appId) {
+    if (appId != null) {
+      secrets.remove(appId);
+    }
+  }
+
+  @Override
+  public boolean isRegistered(String appId) {
+    // ConcurrentHashMap rejects null keys; a null id simply is not registered.
+    return appId != null && secrets.containsKey(appId);
+  }
+
+  /** Gets the SASL secret key for {@code appId}, or {@code null} if none is registered. */
+  @Override
+  public String getSecretKey(String appId) {
+    return appId == null ? null : secrets.get(appId);
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/AbstractAuthRpcHandler.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/AbstractAuthRpcHandler.java
new file mode 100644
index 000000000..36eb89691
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/AbstractAuthRpcHandler.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+
+/**
+ * RPC Handler which performs authentication, and when it's successful, delegates further calls to
+ * another RPC handler. The authentication handshake itself should be implemented by subclasses.
+ *
+ * <p>Until {@link #doAuthChallenge} reports success, messages received with a callback are routed
+ * to the handshake, and messages received without one are rejected with a {@link
+ * SecurityException}. Channel lifecycle and exception events are always forwarded to the delegate.
+ */
+public abstract class AbstractAuthRpcHandler extends BaseMessageHandler {
+  private static final Logger LOG = LoggerFactory.getLogger(AbstractAuthRpcHandler.class);
+
+  /** RpcHandler we will delegate to for authenticated connections. */
+  protected final BaseMessageHandler delegate;
+
+  /**
+   * Set once the handshake succeeds and never reset afterwards. NOTE(review): not volatile;
+   * assumes all calls for one connection arrive on a single thread — confirm.
+   */
+  private boolean isAuthenticated;
+
+  protected AbstractAuthRpcHandler(BaseMessageHandler delegate) {
+    this.delegate = delegate;
+  }
+
+  /**
+   * Responds to an authentication challenge.
+   *
+   * @return Whether the client is authenticated.
+   */
+  protected abstract boolean doAuthChallenge(
+      TransportClient client, RequestMessage message, RpcResponseCallback callback);
+
+  @Override
+  public final void receive(
+      TransportClient client, RequestMessage message, RpcResponseCallback callback) {
+    if (isAuthenticated) {
+      LOG.trace("Already authenticated. Delegating {}", client.getClientId());
+      delegate.receive(client, message, callback);
+    } else {
+      isAuthenticated = doAuthChallenge(client, message, callback);
+    }
+  }
+
+  /**
+   * Forwards a message received without a callback to the delegate.
+   *
+   * @throws SecurityException if the connection has not authenticated yet
+   */
+  @Override
+  public final void receive(TransportClient client, RequestMessage message) {
+    if (isAuthenticated) {
+      delegate.receive(client, message);
+    } else {
+      throw new SecurityException("Unauthenticated call to receive().");
+    }
+  }
+
+  @Override
+  public void channelActive(TransportClient client) {
+    delegate.channelActive(client);
+  }
+
+  @Override
+  public void channelInactive(TransportClient client) {
+    delegate.channelInactive(client);
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause, TransportClient client) {
+    delegate.exceptionCaught(cause, client);
+  }
+
+  /** Returns whether the authentication handshake has completed successfully. */
+  public boolean isAuthenticated() {
+    return isAuthenticated;
+  }
+}
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
new file mode 100644
index 000000000..582988e1d
--- /dev/null
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/sasl/CelebornSaslSuiteJ.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.sasl;
+
+import static org.apache.celeborn.common.network.sasl.SaslUtils.*;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TransportContext;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientBootstrap;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.server.BaseMessageHandler;
+import org.apache.celeborn.common.network.server.TransportServer;
+import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.util.JavaUtils;
+
+/**
+ * Jointly tests {@link CelebornSaslClient} and {@link CelebornSaslServer}, as both are black boxes.
+ */
+public class CelebornSaslSuiteJ {
+  private static final String TEST_USER = "appId";
+  private static final String TEST_SECRET = "secret";
+
+  /** Registers the test credentials in the shared in-memory registry used by the server. */
+  @BeforeClass
+  public static void setup() {
+    SecretRegistryImpl.getInstance().register(TEST_USER, TEST_SECRET);
+  }
+
+  /** Removes the test credentials so the singleton registry does not leak into other suites. */
+  @org.junit.AfterClass
+  public static void teardown() {
+    SecretRegistryImpl.getInstance().unregister(TEST_USER);
+  }
+
+  @Test
+  public void testDigestMatching() {
+    CelebornSaslClient client =
+        new CelebornSaslClient(
+            DIGEST_MD5,
+            DEFAULT_SASL_CLIENT_PROPS,
+            new CelebornSaslClient.ClientCallbackHandler(TEST_USER, TEST_SECRET));
+    CelebornSaslServer server =
+        new CelebornSaslServer(
+            DIGEST_MD5,
+            DEFAULT_SASL_SERVER_PROPS,
+            new CelebornSaslServer.DigestCallbackHandler(SecretRegistryImpl.getInstance()));
+
+    assertFalse(client.isComplete());
+    assertFalse(server.isComplete());
+
+    byte[] clientMessage = client.firstToken();
+
+    while (!client.isComplete()) {
+      clientMessage = client.response(server.response(clientMessage));
+    }
+    assertTrue(server.isComplete());
+
+    // Disposal should invalidate
+    server.dispose();
+    assertFalse(server.isComplete());
+    client.dispose();
+    assertFalse(client.isComplete());
+  }
+
+  @Test
+  public void testDigestNonMatching() {
+    CelebornSaslClient client =
+        new CelebornSaslClient(
+            DIGEST_MD5,
+            DEFAULT_SASL_CLIENT_PROPS,
+            new CelebornSaslClient.ClientCallbackHandler(TEST_USER, "invalid" + TEST_SECRET));
+    CelebornSaslServer server =
+        new CelebornSaslServer(
+            DIGEST_MD5,
+            DEFAULT_SASL_SERVER_PROPS,
+            new CelebornSaslServer.DigestCallbackHandler(SecretRegistryImpl.getInstance()));
+
+    assertFalse(client.isComplete());
+    assertFalse(server.isComplete());
+
+    byte[] clientMessage = client.firstToken();
+
+    try {
+      while (!client.isComplete()) {
+        clientMessage = client.response(server.response(clientMessage));
+      }
+      fail("Should not have completed");
+    } catch (Exception e) {
+      assertTrue(e.getMessage().contains("Mismatched response"));
+      assertFalse(client.isComplete());
+      assertFalse(server.isComplete());
+    } finally {
+      // Release SASL resources even though the exchange failed.
+      client.dispose();
+      server.dispose();
+    }
+  }
+
+  @Test
+  public void testSaslAuth() throws Throwable {
+    BaseMessageHandler rpcHandler = mock(BaseMessageHandler.class);
+    doAnswer(
+            invocation -> {
+              RequestMessage message = (RequestMessage) invocation.getArguments()[1];
+              RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
+              assertEquals("Ping", JavaUtils.bytesToString(message.body().nioByteBuffer()));
+              cb.onSuccess(JavaUtils.stringToBytes("Pong"));
+              return null;
+            })
+        .when(rpcHandler)
+        .receive(
+            any(TransportClient.class), any(RequestMessage.class), any(RpcResponseCallback.class));
+
+    doReturn(true).when(rpcHandler).checkRegistered();
+
+    try (SaslTestCtx ctx = new SaslTestCtx(rpcHandler)) {
+      ByteBuffer response =
+          ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10));
+      assertEquals("Pong", JavaUtils.bytesToString(response));
+    } finally {
+      // There should be 2 terminated events; one for the client, one for the server.
+      Throwable error = null;
+      long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+      while (deadline > System.nanoTime()) {
+        try {
+          verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class));
+          error = null;
+          break;
+        } catch (Throwable t) {
+          error = t;
+          TimeUnit.MILLISECONDS.sleep(10);
+        }
+      }
+      if (error != null) {
+        throw error;
+      }
+    }
+  }
+
+  @Test
+  public void testRpcHandlerDelegate() {
+    // Tests all delegates except for receive(), which is more complicated and already handled
+    // by all other tests.
+    BaseMessageHandler handler = mock(BaseMessageHandler.class);
+    BaseMessageHandler saslHandler = new SaslRpcHandler(null, null, handler, null);
+
+    saslHandler.channelInactive(null);
+    verify(handler).channelInactive(isNull());
+
+    saslHandler.exceptionCaught(null, null);
+    verify(handler).exceptionCaught(isNull(), isNull());
+  }
+
+  /** Starts a SASL-protected server and a client authenticated with the test credentials. */
+  private static class SaslTestCtx implements AutoCloseable {
+
+    final TransportClient client;
+    final TransportServer server;
+    final TransportContext ctx;
+
+    SaslTestCtx(BaseMessageHandler rpcHandler) throws Exception {
+      TransportConf conf = new TransportConf("shuffle", new CelebornConf());
+
+      this.ctx = new TransportContext(conf, rpcHandler);
+      this.server =
+          ctx.createServer(
+              Collections.singletonList(
+                  new SaslServerBootstrap(conf, SecretRegistryImpl.getInstance())));
+      List<TransportClientBootstrap> clientBootstraps = new ArrayList<>();
+      clientBootstraps.add(
+          new SaslClientBootstrap(conf, "appId", new SaslCredentials(TEST_USER, TEST_SECRET)));
+      try {
+        this.client =
+            ctx.createClientFactory(clientBootstraps)
+                .createClient(JavaUtils.getLocalHost(), server.getPort());
+      } catch (Exception e) {
+        close();
+        throw e;
+      }
+    }
+
+    public void close() {
+      if (client != null) {
+        client.close();
+      }
+      if (server != null) {
+        server.close();
+      }
+    }
+  }
+}

Reply via email to