This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 61b7d44 Apply appropriate RPC handler to receive, receiveStream when auth enabled 61b7d44 is described below commit 61b7d446b37cecc45e6d274bbfdde3b745bf068f Author: Sean Owen <sro...@gmail.com> AuthorDate: Fri Apr 17 13:25:12 2020 -0500 Apply appropriate RPC handler to receive, receiveStream when auth enabled --- .../spark/network/crypto/AuthRpcHandler.java | 73 +++----------- .../apache/spark/network/sasl/SaslRpcHandler.java | 60 +++--------- .../network/server/AbstractAuthRpcHandler.java | 107 +++++++++++++++++++++ .../spark/network/crypto/AuthIntegrationSuite.java | 12 +-- .../apache/spark/network/sasl/SparkSaslSuite.java | 3 +- 5 files changed, 142 insertions(+), 113 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 821cc7a..dd31c95 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -29,12 +29,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.server.AbstractAuthRpcHandler; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.TransportConf; /** @@ -46,7 +45,7 @@ import org.apache.spark.network.util.TransportConf; * The delegate will only receive messages if the given connection has been successfully * authenticated. A connection may be authenticated at most once. 
*/ -class AuthRpcHandler extends RpcHandler { +class AuthRpcHandler extends AbstractAuthRpcHandler { private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class); /** Transport configuration. */ @@ -55,36 +54,31 @@ class AuthRpcHandler extends RpcHandler { /** The client channel. */ private final Channel channel; - /** - * RpcHandler we will delegate to for authenticated connections. When falling back to SASL - * this will be replaced with the SASL RPC handler. - */ - @VisibleForTesting - RpcHandler delegate; - /** Class which provides secret keys which are shared by server and client on a per-app basis. */ private final SecretKeyHolder secretKeyHolder; - /** Whether auth is done and future calls should be delegated. */ + /** RPC handler for auth handshake when falling back to SASL auth. */ @VisibleForTesting - boolean doDelegate; + SaslRpcHandler saslHandler; AuthRpcHandler( TransportConf conf, Channel channel, RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + super(delegate); this.conf = conf; this.channel = channel; - this.delegate = delegate; this.secretKeyHolder = secretKeyHolder; } @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - if (doDelegate) { - delegate.receive(client, message, callback); - return; + protected boolean doAuthChallenge( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + if (saslHandler != null) { + return saslHandler.doAuthChallenge(client, message, callback); } int position = message.position(); @@ -98,18 +92,17 @@ class AuthRpcHandler extends RpcHandler { if (conf.saslFallback()) { LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.", channel.remoteAddress()); - delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder); + saslHandler = new SaslRpcHandler(conf, channel, null, secretKeyHolder); message.position(position); message.limit(limit); - delegate.receive(client, message, 
callback); - doDelegate = true; + return saslHandler.doAuthChallenge(client, message, callback); } else { LOG.debug("Unexpected challenge message from client {}, closing channel.", channel.remoteAddress()); callback.onFailure(new IllegalArgumentException("Unknown challenge message.")); channel.close(); } - return; + return false; } // Here we have the client challenge, so perform the new auth protocol and set up the channel. @@ -131,7 +124,7 @@ class AuthRpcHandler extends RpcHandler { LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress()); callback.onFailure(new IllegalArgumentException("Authentication failed.")); channel.close(); - return; + return false; } finally { if (engine != null) { try { @@ -143,40 +136,6 @@ class AuthRpcHandler extends RpcHandler { } LOG.debug("Authorization successful for client {}.", channel.remoteAddress()); - doDelegate = true; - } - - @Override - public void receive(TransportClient client, ByteBuffer message) { - delegate.receive(client, message); - } - - @Override - public StreamCallbackWithID receiveStream( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - return delegate.receiveStream(client, message, callback); + return true; } - - @Override - public StreamManager getStreamManager() { - return delegate.getStreamManager(); - } - - @Override - public void channelActive(TransportClient client) { - delegate.channelActive(client); - } - - @Override - public void channelInactive(TransportClient client) { - delegate.channelInactive(client); - } - - @Override - public void exceptionCaught(Throwable cause, TransportClient client) { - delegate.exceptionCaught(cause, client); - } - } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 355a3de..cc9e88f 100644 --- 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -28,10 +28,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.AbstractAuthRpcHandler; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -43,7 +42,7 @@ import org.apache.spark.network.util.TransportConf; * Note that the authentication process consists of multiple challenge-response pairs, each of * which are individual RPCs. */ -public class SaslRpcHandler extends RpcHandler { +public class SaslRpcHandler extends AbstractAuthRpcHandler { private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); /** Transport configuration. */ @@ -52,37 +51,28 @@ public class SaslRpcHandler extends RpcHandler { /** The client channel. */ private final Channel channel; - /** RpcHandler we will delegate to for authenticated connections. */ - private final RpcHandler delegate; - /** Class which provides secret keys which are shared by server and client on a per-app basis. 
*/ private final SecretKeyHolder secretKeyHolder; private SparkSaslServer saslServer; - private boolean isComplete; - private boolean isAuthenticated; public SaslRpcHandler( TransportConf conf, Channel channel, RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + super(delegate); this.conf = conf; this.channel = channel; - this.delegate = delegate; this.secretKeyHolder = secretKeyHolder; this.saslServer = null; - this.isComplete = false; - this.isAuthenticated = false; } @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - if (isComplete) { - // Authentication complete, delegate to base handler. - delegate.receive(client, message, callback); - return; - } + public boolean doAuthChallenge( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { if (saslServer == null || !saslServer.isComplete()) { ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); SaslMessage saslMessage; @@ -118,43 +108,21 @@ public class SaslRpcHandler extends RpcHandler { if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { logger.debug("SASL authentication successful for channel {}", client); complete(true); - return; + return true; } logger.debug("Enabling encryption for channel {}", client); SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); complete(false); - return; + return true; } - } - - @Override - public void receive(TransportClient client, ByteBuffer message) { - delegate.receive(client, message); - } - - @Override - public StreamCallbackWithID receiveStream( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - return delegate.receiveStream(client, message, callback); - } - - @Override - public StreamManager getStreamManager() { - return delegate.getStreamManager(); - } - - @Override - public void channelActive(TransportClient client) { - delegate.channelActive(client); + return false; } @Override 
public void channelInactive(TransportClient client) { try { - delegate.channelInactive(client); + super.channelInactive(client); } finally { if (saslServer != null) { saslServer.dispose(); @@ -162,11 +130,6 @@ public class SaslRpcHandler extends RpcHandler { } } - @Override - public void exceptionCaught(Throwable cause, TransportClient client) { - delegate.exceptionCaught(cause, client); - } - private void complete(boolean dispose) { if (dispose) { try { @@ -177,7 +140,6 @@ public class SaslRpcHandler extends RpcHandler { } saslServer = null; - isComplete = true; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java new file mode 100644 index 0000000..92eb886 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.server; + +import java.nio.ByteBuffer; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; +import org.apache.spark.network.client.TransportClient; + +/** + * RPC Handler which performs authentication, and when it's successful, delegates further + * calls to another RPC handler. The authentication handshake itself should be implemented + * by subclasses. + */ +public abstract class AbstractAuthRpcHandler extends RpcHandler { + /** RpcHandler we will delegate to for authenticated connections. */ + private final RpcHandler delegate; + + private boolean isAuthenticated; + + protected AbstractAuthRpcHandler(RpcHandler delegate) { + this.delegate = delegate; + } + + /** + * Responds to an authentication challenge. + * + * @return Whether the client is authenticated. + */ + protected abstract boolean doAuthChallenge( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback); + + @Override + public final void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + if (isAuthenticated) { + delegate.receive(client, message, callback); + } else { + isAuthenticated = doAuthChallenge(client, message, callback); + } + } + + @Override + public final void receive(TransportClient client, ByteBuffer message) { + if (isAuthenticated) { + delegate.receive(client, message); + } else { + throw new SecurityException("Unauthenticated call to receive()."); + } + } + + @Override + public final StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + if (isAuthenticated) { + return delegate.receiveStream(client, message, callback); + } else { + throw new SecurityException("Unauthenticated call to receiveStream()."); + } + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void 
channelActive(TransportClient client) { + delegate.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { + delegate.channelInactive(client); + } + + @Override + public void exceptionCaught(Throwable cause, TransportClient client) { + delegate.exceptionCaught(cause, client); + } + + public boolean isAuthenticated() { + return isAuthenticated; + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 2f9dd62..a87a6aa 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -34,7 +34,6 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.sasl.SaslRpcHandler; import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.RpcHandler; @@ -65,8 +64,7 @@ public class AuthIntegrationSuite { ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); assertEquals("Pong", JavaUtils.bytesToString(reply)); - assertTrue(ctx.authRpcHandler.doDelegate); - assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler); + assertNull(ctx.authRpcHandler.saslHandler); } @Test @@ -78,7 +76,7 @@ public class AuthIntegrationSuite { ctx.createClient("client"); fail("Should have failed to create client."); } catch (Exception e) { - assertFalse(ctx.authRpcHandler.doDelegate); + assertFalse(ctx.authRpcHandler.isAuthenticated()); assertFalse(ctx.serverChannel.isActive()); } } @@ -91,6 +89,8 @@ public class 
AuthIntegrationSuite { ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); assertEquals("Pong", JavaUtils.bytesToString(reply)); + assertNotNull(ctx.authRpcHandler.saslHandler); + assertTrue(ctx.authRpcHandler.isAuthenticated()); } @Test @@ -120,7 +120,7 @@ public class AuthIntegrationSuite { ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); fail("Should have failed unencrypted RPC."); } catch (Exception e) { - assertTrue(ctx.authRpcHandler.doDelegate); + assertTrue(ctx.authRpcHandler.isAuthenticated()); } } @@ -151,7 +151,7 @@ public class AuthIntegrationSuite { ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); fail("Should have failed unencrypted RPC."); } catch (Exception e) { - assertTrue(ctx.authRpcHandler.doDelegate); + assertTrue(ctx.authRpcHandler.isAuthenticated()); assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD")); // Verify we receive the complete error message int messageStart = e.getMessage().indexOf("DDDDD"); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index cf2d72f..ecaeec9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -357,7 +357,8 @@ public class SparkSaslSuite { public void testDelegates() throws Exception { Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); for (Method m : rpcHandlerMethods) { - SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes()); + Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes()); + assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class); } } --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org