Repository: spark Updated Branches: refs/heads/master f90ad5d42 -> 5e73138a0
http://git-wip-us.apache.org/repos/asf/spark/blob/5e73138a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java ---------------------------------------------------------------------- diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java new file mode 100644 index 0000000..2c0ce40 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.BaseEncoding; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. (It is not a server in the sense of accepting + * connections on some socket.) + */ +public class SparkSaslServer { + private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + static final String DEFAULT_REALM = "default"; + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + static final String DIGEST = "DIGEST-MD5"; + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder() + .put(Sasl.QOP, "auth") + .put(Sasl.SERVER_AUTH, "true") + .build(); + + /** Identifier for a certain secret key within the secretKeyHolder. */ + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslServer saslServer; + + public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + new DigestCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Determines whether the authentication exchange has completed successfully. + */ + public synchronized boolean isComplete() { + return saslServer != null && saslServer.isComplete(); + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + public synchronized byte[] response(byte[] token) { + try { + return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + public synchronized void dispose() { + if (saslServer != null) { + try { + saslServer.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslServer = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. + */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); + return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); + return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/5e73138a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java ---------------------------------------------------------------------- diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index a9dff31..cd3fea8 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -41,7 +41,7 @@ import org.apache.spark.network.util.JavaUtils; * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- * level shuffle block. */ -public class ExternalShuffleBlockHandler implements RpcHandler { +public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); private final ExternalShuffleBlockManager blockManager; http://git-wip-us.apache.org/repos/asf/spark/blob/5e73138a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java ---------------------------------------------------------------------- diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6bbabc4..b0b19ba 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,8 +17,6 @@ package org.apache.spark.network.shuffle; -import java.io.Closeable; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,15 +34,20 @@ import org.apache.spark.network.util.TransportConf; * BlockTransferService), which has the downside of losing the shuffle data if we lose the * executors. */ -public class ExternalShuffleClient implements ShuffleClient { +public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportClientFactory clientFactory; - private final String appId; - public ExternalShuffleClient(TransportConf conf, String appId) { + private String appId; + + public ExternalShuffleClient(TransportConf conf) { TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); this.clientFactory = context.createClientFactory(); + } + + @Override + public void init(String appId) { this.appId = appId; } @@ -55,6 +58,7 @@ public class ExternalShuffleClient implements ShuffleClient { String execId, String[] blockIds, BlockFetchingListener listener) { + assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { TransportClient client = clientFactory.createClient(host, port); @@ -82,6 +86,7 @@ public class ExternalShuffleClient implements ShuffleClient { int port, String execId, ExecutorShuffleInfo executorInfo) { + assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); http://git-wip-us.apache.org/repos/asf/spark/blob/5e73138a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java ---------------------------------------------------------------------- diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index d46a562..f72ab40 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -20,7 +20,14 @@ package org.apache.spark.network.shuffle; import java.io.Closeable; /** Provides an interface for reading shuffle files, either from an Executor or external service. */ -public interface ShuffleClient extends Closeable { +public abstract class ShuffleClient implements Closeable { + + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. + */ + public void init(String appId) { } + /** * Fetch a sequence of blocks from a remote node asynchronously, * @@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable { * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - public void fetchBlocks( + public abstract void fetchBlocks( String host, int port, String execId, http://git-wip-us.apache.org/repos/asf/spark/blob/5e73138a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java ---------------------------------------------------------------------- diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 0000000..8478120 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + static TransportContext context; + + TransportClientFactory clientFactory; + + /** Provides a secret key holder which always returns the given secret key. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + + private final String secretKey; + + TestSecretKeyHolder(String secretKey) { + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + @Override + public String getSecretKey(String appId) { + return secretKey; + } + } + + + @BeforeClass + public static void beforeAll() throws IOException { + SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); + SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); + conf = new TransportConf(new SystemPropertyConfigProvider()); + context = new TransportContext(conf, handler); + server = context.createServer(); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() { + clientFactory = context.createClientFactory( + Lists.<TransportClientBootstrap>newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + } + + @Test + public void testBadClient() { + clientFactory = context.createClientFactory( + Lists.<TransportClientBootstrap>newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + + try { + // Bootstrap should fail on startup. + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() { + clientFactory = context.createClientFactory( + Lists.<TransportClientBootstrap>newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(new byte[13], 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.<TransportClientBootstrap>newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** RPC handler which simply responds with the message it received. */ + public static class TestRpcHandler extends RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/5e73138a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java ---------------------------------------------------------------------- diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 0000000..67a07f3 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. + */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/5e73138a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java ---------------------------------------------------------------------- diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b3bcf5f..bc101f5 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,8 @@ public class ExternalShuffleIntegrationSuite { final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @Override @@ -164,6 +165,7 @@ public class ExternalShuffleIntegrationSuite { if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); } + client.close(); return res; } @@ -265,7 +267,8 @@ public class ExternalShuffleIntegrationSuite { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); } http://git-wip-us.apache.org/repos/asf/spark/blob/5e73138a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala ---------------------------------------------------------------------- diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index ad1a6f0..0f27f55 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -74,6 +74,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr)) + blockManager.initialize("app-id") tempDirectory = Files.createTempDir() manualClock.setTime(0) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
